From e12af460eda05120f78d8066dd06d9ec1e33055a Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Wed, 26 Feb 2025 09:20:52 -0800 Subject: [PATCH 01/20] Add cuda Blackwell architecture for v12 (#9350) * Add cuda Blackwell architecture for v12 * Win: Split rocm out to separate zip file * Reduce CC matrix The 6.2 and 7.2 architectures only appear on Jetsons, so they were wasting space. The 5.0 should be forward compatible with 5.2 and 5.3. --- CMakePresets.json | 4 ++-- scripts/build_windows.ps1 | 17 +++++++++++++++-- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/CMakePresets.json b/CMakePresets.json index 68546bde..6b654533 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -21,14 +21,14 @@ "name": "CUDA 11", "inherits": [ "CUDA" ], "cacheVariables": { - "CMAKE_CUDA_ARCHITECTURES": "50;52;53;60;61;62;70;72;75;80;86" + "CMAKE_CUDA_ARCHITECTURES": "50;52;53;60;61;70;75;80;86" } }, { "name": "CUDA 12", "inherits": [ "CUDA" ], "cacheVariables": { - "CMAKE_CUDA_ARCHITECTURES": "50;52;53;60;61;62;70;72;75;80;86;87;89;90;90a" + "CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;100" } }, { diff --git a/scripts/build_windows.ps1 b/scripts/build_windows.ps1 index 465cc551..312c3db5 100644 --- a/scripts/build_windows.ps1 +++ b/scripts/build_windows.ps1 @@ -222,13 +222,26 @@ function buildInstaller() { function distZip() { if (Test-Path -Path "${script:SRC_DIR}\dist\windows-amd64") { + if (Test-Path -Path "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm") { + write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64-rocm.zip" + # Temporarily adjust paths so we can retain the same directory structure + Remove-Item -ea 0 -r "${script:SRC_DIR}\dist\windows-amd64-rocm" + mkdir -Force -path "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama" + Write-Output "Extract this ROCm zip file to the same location where you extracted ollama-windows-amd64.zip" > "${script:SRC_DIR}\dist\windows-amd64-rocm\README.txt" + Move-Item -path "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm" -destination "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama" + Compress-Archive -CompressionLevel Optimal -Path "${script:SRC_DIR}\dist\windows-amd64-rocm\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-amd64-rocm.zip" -Force + } + write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64.zip" - Compress-Archive -Path "${script:SRC_DIR}\dist\windows-amd64\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-amd64.zip" -Force + Compress-Archive -CompressionLevel Optimal -Path "${script:SRC_DIR}\dist\windows-amd64\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-amd64.zip" -Force + if (Test-Path -Path "${script:SRC_DIR}\dist\windows-amd64-rocm") { + Move-Item -destination "${script:SRC_DIR}\dist\windows-amd64\lib\ollama\rocm" -path "${script:SRC_DIR}\dist\windows-amd64-rocm\lib\ollama" + } } if (Test-Path -Path "${script:SRC_DIR}\dist\windows-arm64") { write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-arm64.zip" - Compress-Archive -Path "${script:SRC_DIR}\dist\windows-arm64\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-arm64.zip" -Force + Compress-Archive -CompressionLevel Optimal -Path "${script:SRC_DIR}\dist\windows-arm64\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-arm64.zip" -Force } } From 2db96c18e72289928e45704a77f96f7bdfaee30f Mon Sep 17 00:00:00 2001 From: Gordon Kamer Date: Wed, 26 Feb 2025 10:40:53 -0800 Subject: [PATCH 02/20] readme: add Nichey to community integrations (#9370) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 548eb244..bc6b8348 100644 --- a/README.md +++ b/README.md @@ -502,6 +502,7 @@ See the [API documentation](./docs/api.md) for all endpoints. - [LlmTornado](https://github.com/lofcz/llmtornado) (C# library providing a unified interface for major FOSS & Commercial inference APIs) - [Ollama for Zig](https://github.com/dravenk/ollama-zig) - [Abso](https://github.com/lunary-ai/abso) (OpenAI-compatible TypeScript SDK for any LLM provider) +- [Nichey](https://github.com/goodreasonai/nichey) is a Python package for generating custom wikis for your research topic ### Mobile From d7d7e996621ab3d5793d6c1b26b60f0c8c36cd62 Mon Sep 17 00:00:00 2001 From: Jeffrey Morgan Date: Wed, 26 Feb 2025 20:34:44 -0800 Subject: [PATCH 03/20] llama: update llama.cpp vendor code to commit d7cfe1ff (#9356) --- Makefile.sync | 2 +- llama/build-info.cpp | 2 +- llama/llama.cpp/common/common.cpp | 353 +- llama/llama.cpp/common/common.h | 128 +- .../common/json-schema-to-grammar.cpp | 112 +- .../llama.cpp/common/json-schema-to-grammar.h | 16 +- llama/llama.cpp/common/log.cpp | 12 +- llama/llama.cpp/common/log.h | 13 +- llama/llama.cpp/common/sampling.cpp | 120 +- llama/llama.cpp/common/sampling.h | 3 + llama/llama.cpp/examples/llava/clip.cpp | 273 +- llama/llama.cpp/examples/llava/clip.h | 8 + llama/llama.cpp/examples/llava/llava.cpp | 40 +- llama/llama.cpp/include/llama-cpp.h | 8 +- llama/llama.cpp/include/llama.h | 236 +- llama/llama.cpp/src/llama-adapter.cpp | 101 +- llama/llama.cpp/src/llama-adapter.h | 64 +- llama/llama.cpp/src/llama-arch.cpp | 134 +- llama/llama.cpp/src/llama-arch.h | 9 +- llama/llama.cpp/src/llama-chat.cpp | 26 +- llama/llama.cpp/src/llama-chat.h | 2 + llama/llama.cpp/src/llama-context.cpp | 5 +- llama/llama.cpp/src/llama-context.h | 10 +- llama/llama.cpp/src/llama-grammar.cpp | 454 +- llama/llama.cpp/src/llama-grammar.h | 23 +- llama/llama.cpp/src/llama-hparams.cpp | 4 +- llama/llama.cpp/src/llama-hparams.h | 6 +- llama/llama.cpp/src/llama-impl.cpp | 3 +- llama/llama.cpp/src/llama-impl.h | 12 +- llama/llama.cpp/src/llama-kv-cache.cpp | 80 +- llama/llama.cpp/src/llama-kv-cache.h | 2 +- llama/llama.cpp/src/llama-mmap.cpp | 13 +- llama/llama.cpp/src/llama-mmap.h | 3 +- llama/llama.cpp/src/llama-model-loader.cpp | 146 +- llama/llama.cpp/src/llama-model-loader.h | 11 +- llama/llama.cpp/src/llama-model.cpp | 4287 ++++++++++++----- llama/llama.cpp/src/llama-model.h | 298 +- llama/llama.cpp/src/llama-quant.cpp | 71 +- llama/llama.cpp/src/llama-sampling.cpp | 299 +- llama/llama.cpp/src/llama-sampling.h | 22 +- llama/llama.cpp/src/llama-vocab.cpp | 2362 ++++++--- llama/llama.cpp/src/llama-vocab.h | 273 +- llama/llama.cpp/src/llama.cpp | 3697 +++----------- llama/llama.cpp/src/unicode.cpp | 16 +- llama/llama.go | 23 +- llama/mllama.cpp | 1 + llama/patches/0001-cuda.patch | 10 +- llama/patches/0002-pretokenizer.patch | 32 +- llama/patches/0003-embeddings.patch | 10 +- llama/patches/0004-clip-unicode.patch | 8 +- llama/patches/0005-solar-pro.patch | 207 +- llama/patches/0006-conditional-fattn.patch | 4 +- llama/patches/0007-add-mllama-support.patch | 516 +- llama/patches/0008-add-unpad-operator.patch | 59 +- .../0009-fix-deepseek-deseret-regex.patch | 10 +- ...ntain-ordering-for-rules-for-grammar.patch | 6 +- ...sing-arg-in-static-assert-on-windows.patch | 22 - ...sure-KV-cache-is-fully-defragmented.patch} | 50 +- ...se-dynamic-backend-loading-for-clip.patch} | 8 +- ...patch => 0013-sort-devices-by-score.patch} | 2 +- ...arget-ggml-cpu-for-all-cpu-variants.patch} | 6 +- ...atch => 0015-try-catch-backend-load.patch} | 2 +- .../0016-remove-sgemm-global-variables.patch | 55 - ...-filesystem-path-instead-of-wstring.patch} | 2 +- ...remove-amx.patch => 0017-remove-amx.patch} | 6 +- .../0018-fix-clip-compiler-error.patch | 36 + ml/backend/ggml/ggml/include/ggml-backend.h | 2 + ml/backend/ggml/ggml/include/ggml-cpp.h | 1 + ml/backend/ggml/ggml/include/ggml-cpu.h | 4 +- ml/backend/ggml/ggml/include/ggml-metal.h | 2 +- ml/backend/ggml/ggml/include/ggml-vulkan.h | 2 - ml/backend/ggml/ggml/include/ggml.h | 185 +- ml/backend/ggml/ggml/include/gguf.h | 202 + ml/backend/ggml/ggml/src/CMakeLists.txt | 29 +- ml/backend/ggml/ggml/src/ggml-alloc.c | 19 +- ml/backend/ggml/ggml/src/ggml-backend-impl.h | 1 - ml/backend/ggml/ggml/src/ggml-backend-reg.cpp | 5 + ml/backend/ggml/ggml/src/ggml-backend.cpp | 2 +- ml/backend/ggml/ggml/src/ggml-common.h | 2 - .../ggml/ggml/src/ggml-cpu/CMakeLists.txt | 123 +- .../ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp | 2 + .../ggml/ggml/src/ggml-cpu/ggml-cpu-impl.h | 169 +- .../ggml/ggml/src/ggml-cpu/ggml-cpu-quants.c | 2090 ++++++-- ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c | 836 +++- .../ggml/ggml/src/ggml-cpu/ggml-cpu.cpp | 41 +- .../ggml/src/ggml-cpu/llamafile/sgemm.cpp | 836 +++- .../ggml/ggml/src/ggml-cuda/CMakeLists.txt | 12 +- .../ggml/ggml/src/ggml-cuda/binbcast.cu | 54 +- ml/backend/ggml/ggml/src/ggml-cuda/common.cuh | 203 +- ml/backend/ggml/ggml/src/ggml-cuda/concat.cu | 2 +- ml/backend/ggml/ggml/src/ggml-cuda/convert.cu | 2 +- .../ggml/ggml/src/ggml-cuda/cp-async.cuh | 46 + ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu | 97 +- .../ggml/src/ggml-cuda/cross-entropy-loss.cu | 175 +- .../ggml/ggml/src/ggml-cuda/fattn-common.cuh | 178 +- .../ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh | 1021 ++++ .../ggml/ggml/src/ggml-cuda/fattn-tile-f16.cu | 23 +- .../ggml/ggml/src/ggml-cuda/fattn-tile-f32.cu | 25 +- .../ggml/ggml/src/ggml-cuda/fattn-vec-f16.cuh | 8 +- .../ggml/ggml/src/ggml-cuda/fattn-vec-f32.cuh | 8 +- .../ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cu | 648 +++ .../ggml/src/ggml-cuda/fattn-wmma-f16.cuh | 542 +-- ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu | 200 +- ml/backend/ggml/ggml/src/ggml-cuda/getrows.cu | 157 +- .../ggml/ggml/src/ggml-cuda/getrows.cuh | 3 + .../ggml/ggml/src/ggml-cuda/ggml-cuda.cu | 761 +-- ml/backend/ggml/ggml/src/ggml-cuda/gla.cu | 93 + ml/backend/ggml/ggml/src/ggml-cuda/gla.cuh | 3 + ml/backend/ggml/ggml/src/ggml-cuda/mma.cuh | 517 +- ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu | 13 +- ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh | 557 +-- ml/backend/ggml/ggml/src/ggml-cuda/mmv.cu | 150 +- ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cu | 3 +- ml/backend/ggml/ggml/src/ggml-cuda/norm.cu | 236 +- ml/backend/ggml/ggml/src/ggml-cuda/norm.cuh | 2 + .../ggml/ggml/src/ggml-cuda/out-prod.cu | 41 +- ml/backend/ggml/ggml/src/ggml-cuda/pad.cu | 2 +- ml/backend/ggml/ggml/src/ggml-cuda/rope.cu | 366 +- ml/backend/ggml/ggml/src/ggml-cuda/rope.cuh | 2 + ml/backend/ggml/ggml/src/ggml-cuda/softmax.cu | 143 +- .../ggml/ggml/src/ggml-cuda/softmax.cuh | 2 + ml/backend/ggml/ggml/src/ggml-cuda/sum.cu | 4 +- ...attn-mma-f16-instance-ncols1_1-ncols2_8.cu | 10 + ...ttn-mma-f16-instance-ncols1_16-ncols2_1.cu | 10 + ...ttn-mma-f16-instance-ncols1_16-ncols2_2.cu | 10 + ...ttn-mma-f16-instance-ncols1_16-ncols2_4.cu | 10 + ...attn-mma-f16-instance-ncols1_2-ncols2_4.cu | 10 + ...attn-mma-f16-instance-ncols1_2-ncols2_8.cu | 10 + ...ttn-mma-f16-instance-ncols1_32-ncols2_1.cu | 10 + ...ttn-mma-f16-instance-ncols1_32-ncols2_2.cu | 10 + ...attn-mma-f16-instance-ncols1_4-ncols2_2.cu | 10 + ...attn-mma-f16-instance-ncols1_4-ncols2_4.cu | 10 + ...attn-mma-f16-instance-ncols1_4-ncols2_8.cu | 10 + ...ttn-mma-f16-instance-ncols1_64-ncols2_1.cu | 10 + ...attn-mma-f16-instance-ncols1_8-ncols2_1.cu | 10 + ...attn-mma-f16-instance-ncols1_8-ncols2_2.cu | 10 + ...attn-mma-f16-instance-ncols1_8-ncols2_4.cu | 10 + ...attn-mma-f16-instance-ncols1_8-ncols2_8.cu | 10 + ml/backend/ggml/ggml/src/ggml-cuda/unary.cu | 36 + ml/backend/ggml/ggml/src/ggml-cuda/unary.cuh | 3 + ml/backend/ggml/ggml/src/ggml-cuda/wkv6.cu | 4 +- .../ggml/ggml/src/ggml-hip/CMakeLists.txt | 25 +- ml/backend/ggml/ggml/src/ggml-impl.h | 29 +- .../src/ggml-metal/ggml-metal-embed.metal | 113 +- .../ggml/ggml/src/ggml-metal/ggml-metal.m | 242 +- .../ggml/ggml/src/ggml-metal/ggml-metal.metal | 111 +- ml/backend/ggml/ggml/src/ggml.c | 1484 +----- ml/backend/ggml/ggml/src/ggml_darwin_arm64.go | 2 +- ml/backend/ggml/ggml/src/gguf.cpp | 1329 +++++ 149 files changed, 18215 insertions(+), 11009 deletions(-) delete mode 100644 llama/patches/0011-fix-missing-arg-in-static-assert-on-windows.patch rename llama/patches/{0012-llama-Ensure-KV-cache-is-fully-defragmented.patch => 0011-llama-Ensure-KV-cache-is-fully-defragmented.patch} (84%) rename llama/patches/{0013-use-dynamic-backend-loading-for-clip.patch => 0012-use-dynamic-backend-loading-for-clip.patch} (94%) rename llama/patches/{0014-sort-devices-by-score.patch => 0013-sort-devices-by-score.patch} (98%) rename llama/patches/{0015-add-phony-target-ggml-cpu-for-all-cpu-variants.patch => 0014-add-phony-target-ggml-cpu-for-all-cpu-variants.patch} (86%) rename llama/patches/{0017-try-catch-backend-load.patch => 0015-try-catch-backend-load.patch} (99%) delete mode 100644 llama/patches/0016-remove-sgemm-global-variables.patch rename llama/patches/{0018-use-std-filesystem-path-instead-of-wstring.patch => 0016-use-std-filesystem-path-instead-of-wstring.patch} (99%) rename llama/patches/{0019-remove-amx.patch => 0017-remove-amx.patch} (89%) create mode 100644 llama/patches/0018-fix-clip-compiler-error.patch create mode 100644 ml/backend/ggml/ggml/include/gguf.h create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/cp-async.cuh create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cu create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/gla.cu create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/gla.cuh create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu create mode 100644 ml/backend/ggml/ggml/src/gguf.cpp diff --git a/Makefile.sync b/Makefile.sync index 00728274..2bb8d4ab 100644 --- a/Makefile.sync +++ b/Makefile.sync @@ -1,6 +1,6 @@ UPSTREAM=https://github.com/ggerganov/llama.cpp.git WORKDIR=llama/vendor -FETCH_HEAD=46e3556e01b824e52395fb050b29804b6cff2a7c +FETCH_HEAD=d7cfe1ffe0f435d0048a6058d529daf76e072d9c .PHONY: help help: diff --git a/llama/build-info.cpp b/llama/build-info.cpp index e169b926..58c1b080 100644 --- a/llama/build-info.cpp +++ b/llama/build-info.cpp @@ -1,4 +1,4 @@ int LLAMA_BUILD_NUMBER = 0; -char const *LLAMA_COMMIT = "46e3556e01b824e52395fb050b29804b6cff2a7c"; +char const *LLAMA_COMMIT = "d7cfe1ffe0f435d0048a6058d529daf76e072d9c"; char const *LLAMA_COMPILER = ""; char const *LLAMA_BUILD_TARGET = ""; diff --git a/llama/llama.cpp/common/common.cpp b/llama/llama.cpp/common/common.cpp index 4bb140ee..d2b0d50e 100644 --- a/llama/llama.cpp/common/common.cpp +++ b/llama/llama.cpp/common/common.cpp @@ -2,6 +2,9 @@ #define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING #endif +#include "ggml.h" +#include "gguf.h" + #include "common.h" #include "log.h" // Change JSON_ASSERT from assert() to GGML_ASSERT: @@ -70,6 +73,22 @@ #include #endif #define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083 + +// +// CURL utils +// + +using curl_ptr = std::unique_ptr; + +// cannot use unique_ptr for curl_slist, because we cannot update without destroying the old one +struct curl_slist_ptr { + struct curl_slist * ptr = nullptr; + ~curl_slist_ptr() { + if (ptr) { + curl_slist_free_all(ptr); + } + } +}; #endif // LLAMA_USE_CURL using json = nlohmann::ordered_json; @@ -464,6 +483,48 @@ void string_replace_all(std::string & s, const std::string & search, const std:: s = std::move(builder); } +std::string string_join(const std::vector & values, const std::string & separator) { + std::ostringstream result; + for (size_t i = 0; i < values.size(); ++i) { + if (i > 0) { + result << separator; + } + result << values[i]; + } + return result.str(); +} + +std::vector string_split(const std::string & str, const std::string & delimiter) { + std::vector parts; + size_t start = 0; + size_t end = str.find(delimiter); + + while (end != std::string::npos) { + parts.push_back(str.substr(start, end - start)); + start = end + delimiter.length(); + end = str.find(delimiter, start); + } + + parts.push_back(str.substr(start)); + + return parts; +} + +std::string string_repeat(const std::string & str, size_t n) { + if (n == 0) { + return ""; + } + + std::string result; + result.reserve(str.length() * n); + + for (size_t i = 0; i < n; ++i) { + result += str; + } + + return result; +} + std::string string_from(bool value) { return value ? "true" : "false"; } @@ -846,7 +907,7 @@ struct common_init_result common_init_from_params(common_params & params) { } else if (!params.model_url.empty()) { model = common_load_model_from_url(params.model_url, params.model, params.hf_token, mparams); } else { - model = llama_load_model_from_file(params.model.c_str(), mparams); + model = llama_model_load_from_file(params.model.c_str(), mparams); } if (model == NULL) { @@ -854,26 +915,28 @@ struct common_init_result common_init_from_params(common_params & params) { return iparams; } + const llama_vocab * vocab = llama_model_get_vocab(model); + if (params.reranking) { bool ok = true; - if (llama_token_bos(model) == LLAMA_TOKEN_NULL) { - LOG_WRN("%s: warning: model does not have a BOS token, reranking will not work\n", __func__); + if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) { + LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__); ok = false; } - if (llama_token_eos(model) == LLAMA_TOKEN_NULL) { - LOG_WRN("%s: warning: model does not have an EOS token, reranking will not work\n", __func__); + if (llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) { + LOG_WRN("%s: warning: vocab does not have an EOS token, reranking will not work\n", __func__); ok = false; } - if (llama_token_sep(model) == LLAMA_TOKEN_NULL) { - LOG_WRN("%s: warning: model does not have a SEP token, reranking will not work\n", __func__); + if (llama_vocab_sep(vocab) == LLAMA_TOKEN_NULL) { + LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__); ok = false; } if (!ok) { - llama_free_model(model); + llama_model_free(model); return iparams; } @@ -881,10 +944,10 @@ struct common_init_result common_init_from_params(common_params & params) { auto cparams = common_context_params_to_llama(params); - llama_context * lctx = llama_new_context_with_model(model, cparams); + llama_context * lctx = llama_init_from_model(model, cparams); if (lctx == NULL) { LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.c_str()); - llama_free_model(model); + llama_model_free(model); return iparams; } @@ -895,25 +958,26 @@ struct common_init_result common_init_from_params(common_params & params) { if (!params.control_vectors.empty()) { if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1; - if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_n_layer(model); + if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_model_n_layer(model); const auto cvec = common_control_vector_load(params.control_vectors); if (cvec.n_embd == -1) { llama_free(lctx); - llama_free_model(model); + llama_model_free(model); return iparams; } - int err = llama_control_vector_apply(lctx, - cvec.data.data(), - cvec.data.size(), - cvec.n_embd, - params.control_vector_layer_start, - params.control_vector_layer_end); + int err = llama_apply_adapter_cvec( + lctx, + cvec.data.data(), + cvec.data.size(), + cvec.n_embd, + params.control_vector_layer_start, + params.control_vector_layer_end); if (err) { llama_free(lctx); - llama_free_model(model); + llama_model_free(model); return iparams; } @@ -921,12 +985,12 @@ struct common_init_result common_init_from_params(common_params & params) { // load and optionally apply lora adapters for (auto & la : params.lora_adapters) { - llama_lora_adapter_ptr lora; - lora.reset(llama_lora_adapter_init(model, la.path.c_str())); + llama_adapter_lora_ptr lora; + lora.reset(llama_adapter_lora_init(model, la.path.c_str())); if (lora == nullptr) { LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str()); llama_free(lctx); - llama_free_model(model); + llama_model_free(model); return iparams; } @@ -935,17 +999,17 @@ struct common_init_result common_init_from_params(common_params & params) { } if (!params.lora_init_without_apply) { - common_lora_adapters_apply(lctx, params.lora_adapters); + common_set_adapter_lora(lctx, params.lora_adapters); } - if (params.sampling.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) { - LOG_WRN("%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__); + if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) { + LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__); params.sampling.ignore_eos = false; } if (params.sampling.ignore_eos) { - for (llama_token i = 0; i < llama_n_vocab(model); i++) { - if (llama_token_is_eog(model, i)) { + for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) { + if (llama_vocab_is_eog(vocab, i)) { LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY); params.sampling.logit_bias.push_back({i, -INFINITY}); } @@ -966,8 +1030,9 @@ struct common_init_result common_init_from_params(common_params & params) { LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__); std::vector tmp; - llama_token bos = llama_token_bos(model); - llama_token eos = llama_token_eos(model); + llama_token bos = llama_vocab_bos(vocab); + llama_token eos = llama_vocab_eos(vocab); + // some models (e.g. T5) don't have a BOS token if (bos != LLAMA_TOKEN_NULL) { tmp.push_back(bos); @@ -982,7 +1047,7 @@ struct common_init_result common_init_from_params(common_params & params) { if (llama_model_has_encoder(model)) { llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size())); llama_token decoder_start_token_id = llama_model_decoder_start_token(model); - if (decoder_start_token_id == -1) { + if (decoder_start_token_id == LLAMA_TOKEN_NULL) { decoder_start_token_id = bos; } tmp.clear(); @@ -1002,11 +1067,11 @@ struct common_init_result common_init_from_params(common_params & params) { return iparams; } -void common_lora_adapters_apply(struct llama_context * ctx, std::vector & lora) { - llama_lora_adapter_clear(ctx); +void common_set_adapter_lora(struct llama_context * ctx, std::vector & lora) { + llama_clear_adapter_lora(ctx); for (auto & la : lora) { if (la.scale != 0.0f) { - llama_lora_adapter_set(ctx, la.ptr, la.scale); + llama_set_adapter_lora(ctx, la.ptr, la.scale); } } } @@ -1020,7 +1085,6 @@ struct llama_model_params common_model_params_to_llama(common_params & params) { if (params.n_gpu_layers != -1) { mparams.n_gpu_layers = params.n_gpu_layers; } - mparams.rpc_servers = params.rpc_servers.c_str(); mparams.main_gpu = params.main_gpu; mparams.split_mode = params.split_mode; mparams.tensor_split = params.tensor_split; @@ -1123,7 +1187,8 @@ static bool curl_perform_with_retry(const std::string & url, CURL * curl, int ma static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) { // Initialize libcurl - std::unique_ptr curl(curl_easy_init(), &curl_easy_cleanup); + curl_ptr curl(curl_easy_init(), &curl_easy_cleanup); + curl_slist_ptr http_headers; if (!curl) { LOG_ERR("%s: error initializing libcurl\n", __func__); return false; @@ -1137,11 +1202,9 @@ static bool common_download_file(const std::string & url, const std::string & pa // Check if hf-token or bearer-token was specified if (!hf_token.empty()) { - std::string auth_header = "Authorization: Bearer "; - auth_header += hf_token.c_str(); - struct curl_slist *http_headers = NULL; - http_headers = curl_slist_append(http_headers, auth_header.c_str()); - curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers); + std::string auth_header = "Authorization: Bearer " + hf_token; + http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str()); + curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr); } #if defined(_WIN32) @@ -1411,7 +1474,7 @@ struct llama_model * common_load_model_from_url( } } - return llama_load_model_from_file(local_path.c_str(), params); + return llama_model_load_from_file(local_path.c_str(), params); } struct llama_model * common_load_model_from_hf( @@ -1437,6 +1500,80 @@ struct llama_model * common_load_model_from_hf( return common_load_model_from_url(model_url, local_path, hf_token, params); } +/** + * Allow getting the HF file from the HF repo with tag (like ollama), for example: + * - bartowski/Llama-3.2-3B-Instruct-GGUF:q4 + * - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M + * - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s + * Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo) + * + * Return pair of (with "repo" already having tag removed) + * + * Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files. + */ +std::pair common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & hf_token) { + auto parts = string_split(hf_repo_with_tag, ':'); + std::string tag = parts.size() > 1 ? parts.back() : "latest"; + std::string hf_repo = parts[0]; + if (string_split(hf_repo, '/').size() != 2) { + throw std::invalid_argument("error: invalid HF repo format, expected /[:quant]\n"); + } + + // fetch model info from Hugging Face Hub API + json model_info; + curl_ptr curl(curl_easy_init(), &curl_easy_cleanup); + curl_slist_ptr http_headers; + std::string res_str; + std::string url = "https://huggingface.co/v2/" + hf_repo + "/manifests/" + tag; + curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); + typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data); + auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t { + static_cast(data)->append((char * ) ptr, size * nmemb); + return size * nmemb; + }; + curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast(write_callback)); + curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_str); +#if defined(_WIN32) + curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); +#endif + if (!hf_token.empty()) { + std::string auth_header = "Authorization: Bearer " + hf_token; + http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str()); + } + // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response + http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp"); + http_headers.ptr = curl_slist_append(http_headers.ptr, "Accept: application/json"); + curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr); + + CURLcode res = curl_easy_perform(curl.get()); + + if (res != CURLE_OK) { + throw std::runtime_error("error: cannot make GET request to HF API"); + } + + long res_code; + curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code); + if (res_code == 200) { + model_info = json::parse(res_str); + } else if (res_code == 401) { + throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token"); + } else { + throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str())); + } + + // check response + if (!model_info.contains("ggufFile")) { + throw std::runtime_error("error: model does not have ggufFile"); + } + json & gguf_file = model_info.at("ggufFile"); + if (!gguf_file.contains("rfilename")) { + throw std::runtime_error("error: ggufFile does not have rfilename"); + } + + return std::make_pair(hf_repo, gguf_file.at("rfilename")); +} + #else struct llama_model * common_load_model_from_url( @@ -1458,6 +1595,11 @@ struct llama_model * common_load_model_from_hf( return nullptr; } +std::pair common_get_hf_file(const std::string &, const std::string &) { + LOG_WRN("%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__); + return std::make_pair("", ""); +} + #endif // LLAMA_USE_CURL // @@ -1556,21 +1698,23 @@ std::vector common_tokenize( const std::string & text, bool add_special, bool parse_special) { - return common_tokenize(llama_get_model(ctx), text, add_special, parse_special); + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + return common_tokenize(vocab, text, add_special, parse_special); } std::vector common_tokenize( - const struct llama_model * model, + const struct llama_vocab * vocab, const std::string & text, bool add_special, bool parse_special) { // upper limit for the number of tokens int n_tokens = text.length() + 2 * add_special; std::vector result(n_tokens); - n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); + n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); if (n_tokens < 0) { result.resize(-n_tokens); - int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); + int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); GGML_ASSERT(check == -n_tokens); } else { result.resize(n_tokens); @@ -1579,12 +1723,18 @@ std::vector common_tokenize( } std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + return common_token_to_piece(vocab, token, special); +} + +std::string common_token_to_piece(const struct llama_vocab * vocab, llama_token token, bool special) { std::string piece; piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n' - const int n_chars = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special); + const int n_chars = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special); if (n_chars < 0) { piece.resize(-n_chars); - int check = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special); + int check = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special); GGML_ASSERT(check == -n_chars); } else { @@ -1594,13 +1744,19 @@ std::string common_token_to_piece(const struct llama_context * ctx, llama_token return piece; } -std::string common_detokenize(llama_context * ctx, const std::vector & tokens, bool special) { +std::string common_detokenize(const struct llama_context * ctx, const std::vector & tokens, bool special) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + return common_detokenize(vocab, tokens, special); +} + +std::string common_detokenize(const struct llama_vocab * vocab, const std::vector & tokens, bool special) { std::string text; text.resize(std::max(text.capacity(), tokens.size())); - int32_t n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); + int32_t n_chars = llama_detokenize(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); if (n_chars < 0) { text.resize(-n_chars); - n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); + n_chars = llama_detokenize(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization } @@ -1610,103 +1766,6 @@ std::string common_detokenize(llama_context * ctx, const std::vector 0) { - std::vector model_template(res + 1, 0); - llama_model_meta_val_str(model, template_key, model_template.data(), model_template.size()); - return std::string(model_template.data(), model_template.size() - 1); - } - return ""; -} - -bool common_chat_verify_template(const std::string & tmpl) { - llama_chat_message chat[] = {{"user", "test"}}; - int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0); - return res >= 0; -} - -std::string common_chat_apply_template(const struct llama_model * model, - const std::string & tmpl, - const std::vector & msgs, - bool add_ass) { - int alloc_size = 0; - bool fallback = false; // indicate if we must fallback to default chatml - std::vector chat; - for (auto & msg : msgs) { - chat.push_back({msg.role.c_str(), msg.content.c_str()}); - alloc_size += (msg.role.size() + msg.content.size()) * 1.25; - } - - const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str(); - std::vector buf(alloc_size); - - // run the first time to get the total output length - int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size()); - - // error: chat template is not supported - if (res < 0) { - if (ptr_tmpl != nullptr) { - // if the custom "tmpl" is not supported, we throw an error - // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template() - throw std::runtime_error("this custom template is not supported"); - } else { - // If the built-in template is not supported, we default to chatml - res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size()); - fallback = true; - } - } - - // if it turns out that our buffer is too small, we resize it - if ((size_t) res > buf.size()) { - buf.resize(res); - res = llama_chat_apply_template( - fallback ? nullptr : model, - fallback ? "chatml" : ptr_tmpl, - chat.data(), chat.size(), add_ass, buf.data(), buf.size()); - } - - std::string formatted_chat(buf.data(), res); - return formatted_chat; -} - -std::string common_chat_format_single(const struct llama_model * model, - const std::string & tmpl, - const std::vector & past_msg, - const common_chat_msg & new_msg, - bool add_ass) { - std::ostringstream ss; - auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(model, tmpl, past_msg, false); - std::vector chat_new(past_msg); - // if the past_msg ends with a newline, we must preserve it in the formatted version - if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') { - ss << "\n"; - }; - // format chat with new_msg - chat_new.push_back(new_msg); - auto fmt_new_msg = common_chat_apply_template(model, tmpl, chat_new, add_ass); - // get the diff part - ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); - return ss.str(); -} - -std::string common_chat_format_example(const struct llama_model * model, - const std::string & tmpl) { - std::vector msgs = { - {"system", "You are a helpful assistant"}, - {"user", "Hello"}, - {"assistant", "Hi there"}, - {"user", "How are you?"}, - }; - return common_chat_apply_template(model, tmpl, msgs, true); -} - // // KV cache utils // diff --git a/llama/llama.cpp/common/common.h b/llama/llama.cpp/common/common.h index 0d452cf0..efe8e7f7 100644 --- a/llama/llama.cpp/common/common.h +++ b/llama/llama.cpp/common/common.h @@ -4,6 +4,7 @@ #include "llama-cpp.h" +#include #include #include #include @@ -24,11 +25,11 @@ #define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf" -struct common_lora_adapter_info { +struct common_adapter_lora_info { std::string path; float scale; - struct llama_lora_adapter * ptr; + struct llama_adapter_lora * ptr; }; using llama_tokens = std::vector; @@ -103,6 +104,17 @@ enum dimre_method { DIMRE_METHOD_MEAN, }; +enum common_conversation_mode { + COMMON_CONVERSATION_MODE_DISABLED = 0, + COMMON_CONVERSATION_MODE_ENABLED = 1, + COMMON_CONVERSATION_MODE_AUTO = 2, +}; + +struct common_grammar_trigger { + std::string word; + bool at_start; +}; + // sampling parameters struct common_params_sampling { uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler @@ -128,6 +140,7 @@ struct common_params_sampling { int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + float top_n_sigma = -1.00f;// -1.0 = disabled float mirostat_tau = 5.00f; // target entropy float mirostat_eta = 0.10f; // learning rate bool ignore_eos = false; @@ -148,7 +161,11 @@ struct common_params_sampling { COMMON_SAMPLER_TYPE_TEMPERATURE, }; - std::string grammar; // optional BNF-like grammar to constrain sampling + std::string grammar; // optional BNF-like grammar to constrain sampling + bool grammar_lazy = false; + std::vector grammar_trigger_words; // optional trigger words to trigger lazy grammar + std::vector grammar_trigger_tokens; // optional trigger tokens to trigger lazy grammar and print trigger special tokens. + std::set preserved_tokens; std::vector logit_bias; // logit biases to apply @@ -161,15 +178,19 @@ struct common_params_speculative { int32_t n_ctx = 0; // draft context size int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding - int32_t n_min = 5; // minimum number of draft tokens to use for speculative decoding + int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default) float p_split = 0.1f; // speculative decoding split probability - float p_min = 0.9f; // minimum speculative decoding probability (greedy) + float p_min = 0.75f; // minimum speculative decoding probability (greedy) struct cpu_params cpuparams; struct cpu_params cpuparams_batch; - std::string model = ""; // draft model for speculative decoding // NOLINT + std::string hf_repo = ""; // HF repo // NOLINT + std::string hf_file = ""; // HF file // NOLINT + + std::string model = ""; // draft model for speculative decoding // NOLINT + std::string model_url = ""; // model url to download // NOLINT }; struct common_params_vocoder { @@ -178,6 +199,13 @@ struct common_params_vocoder { std::string model = ""; // model path // NOLINT std::string model_url = ""; // model url to download // NOLINT + + bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT +}; + +enum common_reasoning_format { + COMMON_REASONING_FORMAT_NONE, + COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content` }; struct common_params { @@ -240,14 +268,13 @@ struct common_params { std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT std::string logits_file = ""; // file for saving *all* logits // NOLINT - std::string rpc_servers = ""; // comma separated list of RPC servers // NOLINT std::vector in_files; // all input files std::vector antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts) std::vector kv_overrides; - bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_lora_adapter_apply) - std::vector lora_adapters; // lora adapter path with user defined scale + bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_adapter_lora_apply) + std::vector lora_adapters; // lora adapter path with user defined scale std::vector control_vectors; // control vector with user defined scale @@ -271,11 +298,11 @@ struct common_params { bool kl_divergence = false; // compute KL divergence bool usage = false; // print usage + bool completion = false; // print source-able completion script bool use_color = false; // use color to distinguish generations and inputs bool special = false; // enable special token output bool interactive = false; // interactive mode bool interactive_first = false; // wait for user input immediately - bool conversation = false; // conversation mode (does not print special tokens and suffix/prefix) bool prompt_cache_all = false; // save user input and generations to prompt cache bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it @@ -301,6 +328,8 @@ struct common_params { ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V + common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO; + // multimodal models (see examples/llava) std::string mmproj = ""; // path to multimodal projector // NOLINT std::vector image; // path to image file(s) @@ -322,7 +351,9 @@ struct common_params { std::string hostname = "127.0.0.1"; std::string public_path = ""; // NOLINT std::string chat_template = ""; // NOLINT + bool use_jinja = false; // NOLINT bool enable_chat_template = true; + common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; std::vector api_keys; @@ -401,13 +432,13 @@ bool set_process_priority(enum ggml_sched_priority prio); // #ifdef __GNUC__ -#ifdef __MINGW32__ -#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +# if defined(__MINGW32__) && !defined(__clang__) +# define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +# else +# define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) +# endif #else -#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) -#endif -#else -#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) +# define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) #endif LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2) @@ -416,6 +447,10 @@ std::string string_format(const char * fmt, ...); std::string string_strip(const std::string & str); std::string string_get_sortable_timestamp(); +std::string string_join(const std::vector & values, const std::string & separator); +std::vector string_split(const std::string & str, const std::string & delimiter); +std::string string_repeat(const std::string & str, size_t n); + void string_replace_all(std::string & s, const std::string & search, const std::string & replace); template @@ -454,6 +489,11 @@ static bool string_starts_with(const std::string & str, return str.rfind(prefix, 0) == 0; } +static bool string_ends_with(const std::string & str, + const std::string & suffix) { // While we wait for C++20's std::string::ends_with... + return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0; +} + bool string_parse_kv_override(const char * data, std::vector & overrides); void string_process_escapes(std::string & input); @@ -481,7 +521,7 @@ struct common_init_result { llama_model_ptr model; llama_context_ptr context; - std::vector lora; + std::vector lora; }; struct common_init_result common_init_from_params(common_params & params); @@ -495,6 +535,7 @@ struct llama_model * common_load_model_from_url( const std::string & local_path, const std::string & hf_token, const struct llama_model_params & params); + struct llama_model * common_load_model_from_hf( const std::string & repo, const std::string & remote_path, @@ -502,8 +543,12 @@ struct llama_model * common_load_model_from_hf( const std::string & hf_token, const struct llama_model_params & params); +std::pair common_get_hf_file( + const std::string & hf_repo_with_tag, + const std::string & hf_token); + // clear LoRA adapters from context, then apply new list of adapters -void common_lora_adapters_apply(struct llama_context * ctx, std::vector & lora); +void common_set_adapter_lora(struct llama_context * ctx, std::vector & lora); // // Batch utils @@ -541,7 +586,7 @@ std::vector common_tokenize( bool parse_special = false); std::vector common_tokenize( - const struct llama_model * model, + const struct llama_vocab * vocab, const std::string & text, bool add_special, bool parse_special = false); @@ -553,48 +598,23 @@ std::string common_token_to_piece( llama_token token, bool special = true); +std::string common_token_to_piece( + const struct llama_vocab * vocab, + llama_token token, + bool special = true); + // detokenizes a vector of tokens into a string // should work similar to Python's `tokenizer.decode` // optionally renders special/control tokens std::string common_detokenize( - llama_context * ctx, + const struct llama_context * ctx, const std::vector & tokens, bool special = true); -// -// Chat template utils -// - -// same with llama_chat_message, but uses std::string -struct common_chat_msg { - std::string role; - std::string content; -}; - -// Get the built-in chat template for the model. Return empty string if not present. -std::string common_get_builtin_chat_template(const struct llama_model * model); - -// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid -bool common_chat_verify_template(const std::string & tmpl); - -// CPP wrapper for llama_chat_apply_template -// If the built-in template is not supported, we default to chatml -// If the custom "tmpl" is not supported, we throw an error -std::string common_chat_apply_template(const struct llama_model * model, - const std::string & tmpl, - const std::vector & chat, - bool add_ass); - -// Format single message, while taking into account the position of that message in chat history -std::string common_chat_format_single(const struct llama_model * model, - const std::string & tmpl, - const std::vector & past_msg, - const common_chat_msg & new_msg, - bool add_ass); - -// Returns an example of formatted chat -std::string common_chat_format_example(const struct llama_model * model, - const std::string & tmpl); +std::string common_detokenize( + const struct llama_vocab * vocab, + const std::vector & tokens, + bool special = true); // // KV cache utils diff --git a/llama/llama.cpp/common/json-schema-to-grammar.cpp b/llama/llama.cpp/common/json-schema-to-grammar.cpp index 2a8dbd22..30c28808 100644 --- a/llama/llama.cpp/common/json-schema-to-grammar.cpp +++ b/llama/llama.cpp/common/json-schema-to-grammar.cpp @@ -1,4 +1,6 @@ #include "json-schema-to-grammar.h" +#include "common.h" + #include #include #include @@ -11,11 +13,6 @@ using json = nlohmann::ordered_json; -template -static std::string join(Iterator begin, Iterator end, const std::string & separator); - -static std::string repeat(const std::string & str, size_t n); - static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") { auto has_max = max_items != std::numeric_limits::max(); @@ -128,8 +125,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream & if (sub_len > 0) { auto from_sub = from.substr(i + 1); auto to_sub = to.substr(i + 1); - auto sub_zeros = repeat("0", sub_len); - auto sub_nines = repeat("9", sub_len); + auto sub_zeros = string_repeat("0", sub_len); + auto sub_nines = string_repeat("9", sub_len); auto to_reached = false; out << "("; @@ -188,8 +185,8 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream & auto max_digits = max_s.length(); for (auto digits = min_digits; digits < max_digits; digits++) { - uniform_range(min_s, repeat("9", digits)); - min_s = "1" + repeat("0", digits); + uniform_range(min_s, string_repeat("9", digits)); + min_s = "1" + string_repeat("0", digits); out << " | "; } uniform_range(min_s, max_s); @@ -318,49 +315,6 @@ std::unordered_map GRAMMAR_LITERAL_ESCAPES = { std::unordered_set NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'}; std::unordered_set ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'}; -template -std::string join(Iterator begin, Iterator end, const std::string & separator) { - std::ostringstream result; - if (begin != end) { - result << *begin; - for (Iterator it = begin + 1; it != end; ++it) { - result << separator << *it; - } - } - return result.str(); -} - -static std::vector split(const std::string & str, const std::string & delimiter) { - std::vector tokens; - size_t start = 0; - size_t end = str.find(delimiter); - - while (end != std::string::npos) { - tokens.push_back(str.substr(start, end - start)); - start = end + delimiter.length(); - end = str.find(delimiter, start); - } - - tokens.push_back(str.substr(start)); - - return tokens; -} - -static std::string repeat(const std::string & str, size_t n) { - if (n == 0) { - return ""; - } - - std::string result; - result.reserve(str.length() * n); - - for (size_t i = 0; i < n; ++i) { - result += str; - } - - return result; -} - static std::string replacePattern(const std::string & input, const std::regex & regex, const std::function & replacement) { std::smatch match; std::string result; @@ -389,6 +343,7 @@ static std::string format_literal(const std::string & literal) { class SchemaConverter { private: + friend std::string build_grammar(const std::function & cb, const common_grammar_options & options); std::function _fetch_json; bool _dotall; std::unordered_map _rules; @@ -418,7 +373,7 @@ private: for (size_t i = 0; i < alt_schemas.size(); i++) { rules.push_back(visit(alt_schemas[i], name + (name.empty() ? "alternative-" : "-") + std::to_string(i))); } - return join(rules.begin(), rules.end(), " | "); + return string_join(rules, " | "); } std::string _visit_pattern(const std::string & pattern, const std::string & name) { @@ -481,7 +436,7 @@ private: for (const auto & item : ret) { results.push_back(to_rule(item)); } - return std::make_pair(join(results.begin(), results.end(), " "), false); + return std::make_pair(string_join(results, " "), false); }; while (i < length) { @@ -539,7 +494,7 @@ private: } curly_brackets += '}'; i++; - auto nums = split(curly_brackets.substr(1, curly_brackets.length() - 2), ","); + auto nums = string_split(curly_brackets.substr(1, curly_brackets.length() - 2), ","); int min_times = 0; int max_times = std::numeric_limits::max(); try { @@ -809,10 +764,11 @@ private: public: SchemaConverter( const std::function & fetch_json, - bool dotall) + bool dotall, + bool compact_spaces) : _fetch_json(fetch_json), _dotall(dotall) { - _rules["space"] = SPACE_RULE; + _rules["space"] = compact_spaces ? "\" \"?" : SPACE_RULE; } void resolve_refs(json & schema, const std::string & url) { @@ -854,7 +810,7 @@ public: return; } std::string pointer = ref.substr(ref.find('#') + 1); - std::vector tokens = split(pointer, "/"); + std::vector tokens = string_split(pointer, "/"); for (size_t i = 1; i < tokens.size(); ++i) { std::string sel = tokens[i]; if (target.is_null() || !target.contains(sel)) { @@ -905,7 +861,7 @@ public: for (const auto & v : schema["enum"]) { enum_values.push_back(_generate_constant_rule(v)); } - return _add_rule(rule_name, "(" + join(enum_values.begin(), enum_values.end(), " | ") + ") space"); + return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ") space"); } else if ((schema_type.is_null() || schema_type == "object") && (schema.contains("properties") || (schema.contains("additionalProperties") && schema["additionalProperties"] != true))) { @@ -1019,10 +975,10 @@ public: void check_errors() { if (!_errors.empty()) { - throw std::runtime_error("JSON schema conversion failed:\n" + join(_errors.begin(), _errors.end(), "\n")); + throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n")); } if (!_warnings.empty()) { - fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", join(_warnings.begin(), _warnings.end(), "; ").c_str()); + fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str()); } } @@ -1035,11 +991,35 @@ public: } }; -std::string json_schema_to_grammar(const json & schema) { - SchemaConverter converter([](const std::string &) { return json::object(); }, /* dotall= */ false); - auto copy = schema; - converter.resolve_refs(copy, "input"); - converter.visit(copy, ""); +std::string json_schema_to_grammar(const json & schema, bool force_gbnf) { +#ifdef LLAMA_USE_LLGUIDANCE + if (!force_gbnf) { + return "%llguidance {}\nstart: %json " + schema.dump(); + } +#else + (void)force_gbnf; +#endif // LLAMA_USE_LLGUIDANCE + return build_grammar([&](const common_grammar_builder & callbacks) { + auto copy = schema; + callbacks.resolve_refs(copy); + callbacks.add_schema("", copy); + }); +} + +std::string build_grammar(const std::function & cb, const common_grammar_options & options) { + SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall, options.compact_spaces); + common_grammar_builder builder { + /* .add_rule = */ [&](const std::string & name, const std::string & rule) { + return converter._add_rule(name, rule); + }, + /* .add_schema = */ [&](const std::string & name, const nlohmann::ordered_json & schema) { + return converter.visit(schema, name == "root" ? "" : name); + }, + /* .resolve_refs = */ [&](nlohmann::ordered_json & schema) { + converter.resolve_refs(schema, ""); + } + }; + cb(builder); converter.check_errors(); return converter.format_grammar(); } diff --git a/llama/llama.cpp/common/json-schema-to-grammar.h b/llama/llama.cpp/common/json-schema-to-grammar.h index 41623b34..62a3b0a4 100644 --- a/llama/llama.cpp/common/json-schema-to-grammar.h +++ b/llama/llama.cpp/common/json-schema-to-grammar.h @@ -5,4 +5,18 @@ #define JSON_ASSERT GGML_ASSERT #include "json.hpp" -std::string json_schema_to_grammar(const nlohmann::ordered_json& schema); +std::string json_schema_to_grammar(const nlohmann::ordered_json & schema, + bool force_gbnf = false); + +struct common_grammar_builder { + std::function add_rule; + std::function add_schema; + std::function resolve_refs; +}; + +struct common_grammar_options { + bool dotall = false; + bool compact_spaces = false; +}; + +std::string build_grammar(const std::function & cb, const common_grammar_options & options = {}); diff --git a/llama/llama.cpp/common/log.cpp b/llama/llama.cpp/common/log.cpp index 04c7c0ed..52b31470 100644 --- a/llama/llama.cpp/common/log.cpp +++ b/llama/llama.cpp/common/log.cpp @@ -1,5 +1,6 @@ #include "log.h" +#include #include #include #include @@ -14,16 +15,6 @@ void common_log_set_verbosity_thold(int verbosity) { common_log_verbosity_thold = verbosity; } -#define LOG_COL_DEFAULT "\033[0m" -#define LOG_COL_BOLD "\033[1m" -#define LOG_COL_RED "\033[31m" -#define LOG_COL_GREEN "\033[32m" -#define LOG_COL_YELLOW "\033[33m" -#define LOG_COL_BLUE "\033[34m" -#define LOG_COL_MAGENTA "\033[35m" -#define LOG_COL_CYAN "\033[36m" -#define LOG_COL_WHITE "\033[37m" - static int64_t t_us() { return std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()).count(); } @@ -206,6 +197,7 @@ public: vsnprintf(entry.msg.data(), entry.msg.size(), ss.str().c_str(), args_copy); } #endif + va_end(args_copy); } entry.level = level; diff --git a/llama/llama.cpp/common/log.h b/llama/llama.cpp/common/log.h index 66605cc6..c56bb50d 100644 --- a/llama/llama.cpp/common/log.h +++ b/llama/llama.cpp/common/log.h @@ -2,9 +2,20 @@ #include "ggml.h" // for ggml_log_level +#define LOG_CLR_TO_EOL "\033[K\r" +#define LOG_COL_DEFAULT "\033[0m" +#define LOG_COL_BOLD "\033[1m" +#define LOG_COL_RED "\033[31m" +#define LOG_COL_GREEN "\033[32m" +#define LOG_COL_YELLOW "\033[33m" +#define LOG_COL_BLUE "\033[34m" +#define LOG_COL_MAGENTA "\033[35m" +#define LOG_COL_CYAN "\033[36m" +#define LOG_COL_WHITE "\033[37m" + #ifndef __GNUC__ # define LOG_ATTRIBUTE_FORMAT(...) -#elif defined(__MINGW32__) +#elif defined(__MINGW32__) && !defined(__clang__) # define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) #else # define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) diff --git a/llama/llama.cpp/common/sampling.cpp b/llama/llama.cpp/common/sampling.cpp index e83a971c..37a0d9c8 100644 --- a/llama/llama.cpp/common/sampling.cpp +++ b/llama/llama.cpp/common/sampling.cpp @@ -113,7 +113,10 @@ struct common_sampler { void set_logits(struct llama_context * ctx, int idx) { const auto * logits = llama_get_logits_ith(ctx, idx); - const int n_vocab = llama_n_vocab(llama_get_model(ctx)); + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + const int n_vocab = llama_vocab_n_tokens(vocab); cur.resize(n_vocab); @@ -131,24 +134,47 @@ std::string common_params_sampling::print() const { snprintf(result, sizeof(result), "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" "\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n" - "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n" + "\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n" "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f", penalty_last_n, penalty_repeat, penalty_freq, penalty_present, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, - top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp, + top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp, mirostat, mirostat_eta, mirostat_tau); return std::string(result); } struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) { + const llama_vocab * vocab = llama_model_get_vocab(model); + llama_sampler_chain_params lparams = llama_sampler_chain_default_params(); lparams.no_perf = params.no_perf; + struct llama_sampler * grmr; + if (params.grammar.compare(0, 11, "%llguidance") == 0) { +#ifdef LLAMA_USE_LLGUIDANCE + grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str()); +#else + GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled"); +#endif // LLAMA_USE_LLGUIDANCE + } else { + std::vector trigger_words; + trigger_words.reserve(params.grammar_trigger_words.size()); + for (const auto & str : params.grammar_trigger_words) { + trigger_words.push_back(str.word.c_str()); + } + + grmr = params.grammar_lazy + ? llama_sampler_init_grammar_lazy(vocab, params.grammar.c_str(), "root", + trigger_words.data(), trigger_words.size(), + params.grammar_trigger_tokens.data(), params.grammar_trigger_tokens.size()) + : llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"); + } + auto * result = new common_sampler { /* .params = */ params, - /* .grmr = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"), + /* .grmr = */ grmr, /* .chain = */ llama_sampler_chain_init(lparams), /* .prev = */ ring_buffer(std::max(32, params.n_prev)), /* .cur = */ {}, @@ -157,56 +183,62 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co llama_sampler_chain_add(result->chain, llama_sampler_init_logit_bias( - llama_n_vocab(model), + llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data())); if (params.mirostat == 0) { - for (const auto & cnstr : params.samplers) { - switch (cnstr) { - case COMMON_SAMPLER_TYPE_DRY: - { - std::vector c_breakers; - c_breakers.reserve(params.dry_sequence_breakers.size()); - for (const auto & str : params.dry_sequence_breakers) { - c_breakers.push_back(str.c_str()); - } + if (params.top_n_sigma >= 0) { + llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); + llama_sampler_chain_add(result->chain, llama_sampler_init_temp (params.temp)); + llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma)); + } else { + for (const auto & cnstr : params.samplers) { + switch (cnstr) { + case COMMON_SAMPLER_TYPE_DRY: + { + std::vector c_breakers; + c_breakers.reserve(params.dry_sequence_breakers.size()); + for (const auto & str : params.dry_sequence_breakers) { + c_breakers.push_back(str.c_str()); + } - llama_sampler_chain_add(result->chain, llama_sampler_init_dry (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); - } - break; - case COMMON_SAMPLER_TYPE_TOP_K: - llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); - break; - case COMMON_SAMPLER_TYPE_TOP_P: - llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep)); - break; - case COMMON_SAMPLER_TYPE_MIN_P: - llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep)); - break; - case COMMON_SAMPLER_TYPE_XTC: - llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); - break; - case COMMON_SAMPLER_TYPE_TYPICAL_P: - llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep)); - break; - case COMMON_SAMPLER_TYPE_TEMPERATURE: - llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); - break; - case COMMON_SAMPLER_TYPE_INFILL: - llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model)); - break; - case COMMON_SAMPLER_TYPE_PENALTIES: - llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); - break; - default: - GGML_ASSERT(false && "unknown sampler type"); + llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); + } + break; + case COMMON_SAMPLER_TYPE_TOP_K: + llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); + break; + case COMMON_SAMPLER_TYPE_TOP_P: + llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep)); + break; + case COMMON_SAMPLER_TYPE_MIN_P: + llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep)); + break; + case COMMON_SAMPLER_TYPE_XTC: + llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); + break; + case COMMON_SAMPLER_TYPE_TYPICAL_P: + llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep)); + break; + case COMMON_SAMPLER_TYPE_TEMPERATURE: + llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); + break; + case COMMON_SAMPLER_TYPE_INFILL: + llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab)); + break; + case COMMON_SAMPLER_TYPE_PENALTIES: + llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); + break; + default: + GGML_ASSERT(false && "unknown sampler type"); + } } } llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed)); } else if (params.mirostat == 1) { llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); - llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100)); + llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100)); } else if (params.mirostat == 2) { llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta)); diff --git a/llama/llama.cpp/common/sampling.h b/llama/llama.cpp/common/sampling.h index 348911b1..2064421d 100644 --- a/llama/llama.cpp/common/sampling.h +++ b/llama/llama.cpp/common/sampling.h @@ -102,3 +102,6 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr); std::vector common_sampler_types_from_names(const std::vector & names, bool allow_alt_names); std::vector common_sampler_types_from_chars(const std::string & chars); + +llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, + const char * grammar_kind, const char * grammar_data); diff --git a/llama/llama.cpp/examples/llava/clip.cpp b/llama/llama.cpp/examples/llava/clip.cpp index 86b91d5c..54265beb 100644 --- a/llama/llama.cpp/examples/llava/clip.cpp +++ b/llama/llama.cpp/examples/llava/clip.cpp @@ -7,6 +7,7 @@ #include "ggml-cpu.h" #include "ggml-alloc.h" #include "ggml-backend.h" +#include "gguf.h" #ifdef GGML_USE_CUDA #include "ggml-cuda.h" @@ -39,6 +40,7 @@ #include #include #include +#include #include #include #include @@ -114,6 +116,7 @@ static std::string format(const char * fmt, ...) { #define KEY_HAS_VIS_ENC "clip.has_vision_encoder" #define KEY_HAS_LLAVA_PROJ "clip.has_llava_projector" #define KEY_HAS_MINICPMV_PROJ "clip.has_minicpmv_projector" +#define KEY_HAS_GLM_PROJ "clip.has_glm_projector" #define KEY_MINICPMV_VERSION "clip.minicpmv_version" #define KEY_HAS_QWEN2VL_MERGER "clip.has_qwen2vl_merger" #define KEY_USE_GELU "clip.use_gelu" @@ -131,6 +134,7 @@ static std::string format(const char * fmt, ...) { #define KEY_IMAGE_MEAN "clip.vision.image_mean" #define KEY_IMAGE_STD "clip.vision.image_std" #define KEY_PROJ_TYPE "clip.projector_type" +#define KEY_FEATURE_LAYER "clip.vision.feature_layer" #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type" #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" @@ -172,6 +176,15 @@ static std::string format(const char * fmt, ...) { #define TN_MINICPMV_ATTN "resampler.attn.%s.%s" #define TN_MINICPMV_LN "resampler.ln_%s.%s" +#define TN_GLM_ADAPER_CONV "adapter.conv.%s" +#define TN_GLM_ADAPTER_LINEAR "adapter.linear.linear.%s" +#define TN_GLM_ADAPTER_NORM_1 "adapter.linear.norm1.%s" +#define TN_GLM_ADAPTER_D_H_2_4H "adapter.linear.dense_h_to_4h.%s" +#define TN_GLM_ADAPTER_GATE "adapter.linear.gate.%s" +#define TN_GLM_ADAPTER_D_4H_2_H "adapter.linear.dense_4h_to_h.%s" +#define TN_GLM_BOI_W "adapter.boi" +#define TN_GLM_EOI_W "adapter.eoi" + enum projector_type { PROJECTOR_TYPE_MLP, @@ -179,6 +192,7 @@ enum projector_type { PROJECTOR_TYPE_LDP, PROJECTOR_TYPE_LDPV2, PROJECTOR_TYPE_RESAMPLER, + PROJECTOR_TYPE_GLM_EDGE, PROJECTOR_TYPE_MERGER, PROJECTOR_TYPE_UNKNOWN, }; @@ -188,6 +202,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_LDP, "ldp" }, { PROJECTOR_TYPE_LDPV2, "ldpv2"}, { PROJECTOR_TYPE_RESAMPLER, "resampler"}, + { PROJECTOR_TYPE_GLM_EDGE, "adapter"}, { PROJECTOR_TYPE_MERGER, "qwen2vl_merger"}, }; @@ -275,7 +290,7 @@ static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) { { const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i); int arr_n = gguf_get_arr_n(ctx_gguf, i); - const void * data = gguf_get_arr_data(ctx_gguf, i); + const void * data = arr_type == GGUF_TYPE_STRING ? nullptr : gguf_get_arr_data(ctx_gguf, i); std::stringstream ss; ss << "["; for (int j = 0; j < arr_n; j++) { @@ -444,8 +459,9 @@ struct clip_hparams { char mm_patch_merge_type[32] = "flat"; // spatial_unpad or flat (default) - int32_t image_grid_pinpoints[32]; + std::vector image_grid_pinpoints; int32_t image_crop_resolution; + std::unordered_set vision_feature_layer; }; struct clip_layer { @@ -512,6 +528,12 @@ struct clip_vision_model { struct ggml_tensor * mm_4_w = NULL; struct ggml_tensor * mm_4_b = NULL; + //GLMV-Edge projection + struct ggml_tensor * mm_model_adapter_conv_w; + struct ggml_tensor * mm_model_adapter_conv_b; + struct ggml_tensor * boi_w; + struct ggml_tensor * eoi_w; + // MobileVLM projection struct ggml_tensor * mm_model_mlp_1_w; struct ggml_tensor * mm_model_mlp_1_b; @@ -572,12 +594,14 @@ struct clip_ctx { bool has_vision_encoder = false; bool has_llava_projector = false; bool has_minicpmv_projector = false; + bool has_glm_projector = false; bool has_qwen2vl_merger = false; int minicpmv_version = 2; struct clip_vision_model vision_model; projector_type proj_type = PROJECTOR_TYPE_MLP; + int32_t max_feature_layer; float image_mean[3]; float image_std[3]; bool use_gelu = false; @@ -644,13 +668,12 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 const int hidden_size = hparams.hidden_size; const int n_head = hparams.n_head; const int d_head = hidden_size / n_head; - int n_layer = hparams.n_layer; const float eps = hparams.eps; int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4}; const int batch_size = imgs->size; - if (ctx->has_llava_projector || ctx->has_minicpmv_projector) { + if (ctx->has_llava_projector || ctx->has_minicpmv_projector || ctx->has_glm_projector) { GGML_ASSERT(batch_size == 1); } @@ -730,6 +753,9 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 else if (ctx->minicpmv_version == 3) { pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1); } + else if (ctx->minicpmv_version == 4) { + pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1); + } ggml_set_name(pos_embed, "pos_embed"); ggml_set_input(pos_embed); } @@ -742,14 +768,19 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b); } + std::vector embedding_stack; + const auto & vision_feature_layer = hparams.vision_feature_layer; + // loop over layers - if (ctx->has_minicpmv_projector || ctx->has_qwen2vl_merger) { - // TODO: figure out why we doing thing in this way ??? - n_layer += 1; - } - for (int il = 0; il < n_layer - 1; il++) { + for (int il = 0; il < ctx->max_feature_layer; il++) { struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states + // If this is an embedding feature layer, save the output. + // NOTE: 0 index here refers to the input to the encoder. + if (vision_feature_layer.find(il) != vision_feature_layer.end()) { + embedding_stack.push_back(embeddings); + } + //const size_t nb_q_w = model.layers[il].q_w->nb[0]; // layernorm1 @@ -837,7 +868,6 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 cur = ggml_add(ctx0, embeddings, cur); embeddings = cur; - } // post-layernorm @@ -848,6 +878,19 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b); } + // final layer is a vision feature layer + if (vision_feature_layer.find(ctx->max_feature_layer) != vision_feature_layer.end()) { + embedding_stack.push_back(embeddings); + } + + // If feature layers are explicitly set, stack them (if we have multiple) + if (!embedding_stack.empty()) { + embeddings = embedding_stack[0]; + for (size_t i = 1; i < embedding_stack.size(); i++) { + embeddings = ggml_concat(ctx0, embeddings, embedding_stack[i], 0); + } + } + // llava projector if (ctx->has_llava_projector) { embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]); @@ -1065,6 +1108,11 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 n_head = hidden_size/d_head; num_query = 64; } + else if (ctx->minicpmv_version == 4) { + hidden_size = 3584; + n_head = hidden_size/d_head; + num_query = 64; + } struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b); Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head)); @@ -1099,7 +1147,33 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 GGML_ASSERT(false); } } - else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { + // glm projector + else if (ctx->has_glm_projector) { + if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) { + size_t gridsz = (size_t)sqrt(embeddings->ne[1]); + embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings,1,0,2,3)); + embeddings = ggml_reshape_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]); + embeddings = ggml_conv_2d(ctx0, model.mm_model_adapter_conv_w, embeddings, 2, 2, 0, 0, 1, 1); + embeddings = ggml_reshape_3d(ctx0, embeddings,embeddings->ne[0]*embeddings->ne[1] , embeddings->ne[2], batch_size); + embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings, 1, 0, 2, 3)); + embeddings = ggml_add(ctx0, embeddings, model.mm_model_adapter_conv_b); + //GLU + { + embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings); + embeddings = ggml_norm(ctx0, embeddings, eps); + embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_q_w), model.mm_model_ln_q_b); + embeddings = ggml_gelu_inplace(ctx0, embeddings); + struct ggml_tensor * x = embeddings; + embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings); + x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w,x); + embeddings = ggml_silu_inplace(ctx0, embeddings); + embeddings = ggml_mul(ctx0, embeddings,x); + embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings); + } + } else { + GGML_ABORT("fatel error"); + } + } else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size); embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); @@ -1268,6 +1342,11 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { new_clip->minicpmv_version = gguf_get_val_i32(ctx, idx); } + idx = gguf_find_key(ctx, KEY_HAS_GLM_PROJ); + if (idx != -1) { + new_clip->has_glm_projector = gguf_get_val_bool(ctx, idx); + } + idx = gguf_find_key(ctx, KEY_HAS_QWEN2VL_MERGER); if (idx != -1) { new_clip->has_qwen2vl_merger = gguf_get_val_bool(ctx, idx); @@ -1292,6 +1371,7 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { LOG_INF("%s: vision_encoder: %d\n", __func__, new_clip->has_vision_encoder); LOG_INF("%s: llava_projector: %d\n", __func__, new_clip->has_llava_projector); LOG_INF("%s: minicpmv_projector: %d\n", __func__, new_clip->has_minicpmv_projector); + LOG_INF("%s: glm_projector: %d\n", __func__, new_clip->has_glm_projector); LOG_INF("%s: model size: %.2f MB\n", __func__, model_size / 1024.0 / 1024.0); LOG_INF("%s: metadata size: %.2f MB\n", __func__, ggml_get_mem_size(meta) / 1024.0 / 1024.0); } @@ -1402,14 +1482,26 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { int idx = get_key_idx(ctx, KEY_IMAGE_GRID_PINPOINTS); int n = gguf_get_arr_n(ctx, idx); const int32_t * pinpoints = (const int32_t *)gguf_get_arr_data(ctx, idx); - for (int i = 0; i < 32 && i < n && pinpoints[i] != 0; ++i) { - hparams.image_grid_pinpoints[i] = pinpoints[i]; + for (int i = 0; i < n; ++i) { + hparams.image_grid_pinpoints.push_back(pinpoints[i]); } - if (n < 32) - hparams.image_grid_pinpoints[n] = 0; - } catch (std::runtime_error & /*e*/) { - hparams.image_grid_pinpoints[0]=0; - } + } catch (std::runtime_error & /*e*/) { } + + // Load the vision feature layer indices if they are explicitly provided; + // if multiple vision feature layers are present, the values will be concatenated + // to form the final visual features. + // NOTE: gguf conversions should standardize the values of the vision feature layer to + // be non-negative, since we use -1 to mark values as unset here. + try { + int idx = get_key_idx(ctx, KEY_FEATURE_LAYER); + int n = gguf_get_arr_n(ctx, idx); + + const int32_t * vision_feature_layer = (const int32_t *)gguf_get_arr_data(ctx, idx); + + for (int i = 0; i < n; ++i) { + hparams.vision_feature_layer.insert(vision_feature_layer[i]); + } + } catch (std::runtime_error & /*e*/) { } try { int idx = get_key_idx(ctx, KEY_MM_PATCH_MERGE_TYPE); @@ -1435,6 +1527,9 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { new_clip->image_std[i] = std_data[i]; } + // Calculate the deepest feature layer based on hparams and projector type + new_clip->max_feature_layer = get_deepest_feature_layer(new_clip); + if (verbosity >= 2) { LOG_INF("\n%s: vision model hparams\n", __func__); LOG_INF("image_size %d\n", hparams.image_size); @@ -1448,8 +1543,13 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { LOG_INF("v_image_mean %f %f %f\n", new_clip->image_mean[0], new_clip->image_mean[1], new_clip->image_mean[2]); LOG_INF("v_image_std %f %f %f\n", new_clip->image_std[0], new_clip->image_std[1], new_clip->image_std[2]); LOG_INF("v_image_grid_pinpoints: "); - for (int i = 0; i < 32 && (hparams.image_grid_pinpoints[i] != 0); ++i) { - LOG_INF("%d ", hparams.image_grid_pinpoints[i]); + for (const auto & pp : hparams.image_grid_pinpoints) { + LOG_INF("%d ", pp); + } + LOG_INF("\n"); + LOG_INF("v_vision_feature_layer: "); + for (const auto & feature_layer: hparams.vision_feature_layer) { + LOG_INF("%d ", feature_layer); } LOG_INF("\n"); LOG_INF("v_mm_patch_merge_type: %s\n", hparams.mm_patch_merge_type); @@ -1584,6 +1684,18 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { vision_model.mm_model_ln_post_w = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "weight")); vision_model.mm_model_ln_post_b = get_tensor(new_clip->ctx_data, format(TN_MINICPMV_LN, "post", "bias")); } + else if (new_clip->proj_type == PROJECTOR_TYPE_GLM_EDGE) { + vision_model.mm_model_adapter_conv_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPER_CONV, "weight")); + vision_model.mm_model_adapter_conv_b = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPER_CONV, "bias")); + vision_model.mm_model_mlp_0_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_LINEAR,"weight")); + vision_model.mm_model_ln_q_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_NORM_1,"weight")); + vision_model.mm_model_ln_q_b = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_NORM_1,"bias")); + vision_model.mm_model_mlp_1_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_D_H_2_4H,"weight")); + vision_model.mm_model_mlp_2_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_GATE,"weight")); + vision_model.mm_model_mlp_3_w = get_tensor(new_clip->ctx_data, format(TN_GLM_ADAPTER_D_4H_2_H,"weight")); + vision_model.boi_w = get_tensor(new_clip->ctx_data, TN_GLM_BOI_W); + vision_model.eoi_w = get_tensor(new_clip->ctx_data, TN_GLM_EOI_W); + } else if (new_clip->proj_type == PROJECTOR_TYPE_MERGER) { vision_model.mm_0_w = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "weight")); vision_model.mm_0_b = get_tensor(new_clip->ctx_data, format(TN_LLAVA_PROJ, 0, "bias")); @@ -1676,11 +1788,11 @@ void clip_image_f32_batch_free(struct clip_image_f32_batch * batch) { } } -static void build_clip_img_from_data(const stbi_uc * data, int nx, int ny, clip_image_u8 * img) { +void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, struct clip_image_u8 * img) { img->nx = nx; img->ny = ny; img->buf.resize(3 * nx * ny); - memcpy(img->buf.data(), data, img->buf.size()); + memcpy(img->buf.data(), rgb_pixels, img->buf.size()); } bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) { @@ -1690,7 +1802,7 @@ bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) { LOG_ERR("%s: failed to load image '%s'\n", __func__, fname); return false; } - build_clip_img_from_data(data, nx, ny, img); + clip_build_img_from_pixels(data, nx, ny, img); stbi_image_free(data); return true; } @@ -1702,7 +1814,7 @@ bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length LOG_ERR("%s: failed to decode image bytes\n", __func__); return false; } - build_clip_img_from_data(data, nx, ny, img); + clip_build_img_from_pixels(data, nx, ny, img); stbi_image_free(data); return true; } @@ -2058,6 +2170,7 @@ static std::vector> uhd_slice_image(const clip_imag images[images.size()-1].push_back(patch); } } + clip_image_u8_free(refine_image); } return images; } @@ -2096,6 +2209,13 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli clip_image_f32_free(res); } } + for (size_t i = 0; i < imgs.size(); ++i) { + for (size_t j = 0; j < imgs[i].size(); ++j) { + if (imgs[i][j] != nullptr) { + clip_image_u8_free(imgs[i][j]); + } + } + } return true; } else if (ctx->has_qwen2vl_merger) { @@ -2116,6 +2236,20 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli return true; } + if (ctx->has_glm_projector) { + res_imgs->size = 1; + res_imgs->data = new clip_image_f32[res_imgs->size]; + clip_image_u8 resized_image; + int32_t sz=ctx->vision_model.hparams.image_size; + bicubic_resize(*img, resized_image,sz,sz); + clip_image_f32 * res = clip_image_f32_init(); + //clip_image_save_to_bmp(resized_image, "resized.bmp"); + normalize_image_u8_to_f32(&resized_image, res, ctx->image_mean, ctx->image_std); + res_imgs->data[0] = *res; + clip_image_f32_free(res); + return true; + } + bool pad_to_square = true; if (!ctx->has_vision_encoder) { LOG_ERR("This gguf file seems to have no vision encoder\n"); @@ -2160,10 +2294,10 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, cli } } } else { - if (params.image_grid_pinpoints[0] != 0) { + if (!params.image_grid_pinpoints.empty()) { // "spatial_unpad" with "anyres" processing for llava-1.6 std::vector> possible_resolutions; - for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i+=2) { + for (size_t i = 0; i < params.image_grid_pinpoints.size(); i+=2) { possible_resolutions.push_back({params.image_grid_pinpoints[i], params.image_grid_pinpoints[i+1]}); } std::pair best_resolution = select_best_resolution({img->nx, img->ny}, possible_resolutions); @@ -2301,7 +2435,8 @@ void clip_free(clip_ctx * ctx) { } size_t clip_embd_nbytes(const struct clip_ctx * ctx) { - return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx) * sizeof(float); + int extra_tokens = ctx->has_glm_projector ? 2 : 0; + return (clip_n_patches(ctx) + extra_tokens) * clip_n_mmproj_embd(ctx) * sizeof(float); } size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w) { @@ -2328,7 +2463,14 @@ const char * clip_patch_merge_type(const struct clip_ctx * ctx) { } const int32_t * clip_image_grid(const struct clip_ctx * ctx) { - return ctx->vision_model.hparams.image_grid_pinpoints; + if (ctx->vision_model.hparams.image_grid_pinpoints.size()) { + return &ctx->vision_model.hparams.image_grid_pinpoints.front(); + } + return nullptr; +} + +size_t get_clip_image_grid_size(const struct clip_ctx * ctx) { + return ctx->vision_model.hparams.image_grid_pinpoints.size(); } int clip_n_patches(const struct clip_ctx * ctx) { @@ -2343,7 +2485,7 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size); - if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2) { + if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2 || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) { n_patches /= 4; } else if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) { if (ctx->minicpmv_version == 2) { @@ -2352,6 +2494,9 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i else if (ctx->minicpmv_version == 3) { n_patches = 64; } + else if (ctx->minicpmv_version == 4) { + n_patches = 64; + } } else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { int patch_size = params.patch_size * 2; int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0); @@ -2473,6 +2618,12 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima if (ctx->has_minicpmv_projector) { GGML_ASSERT(batch_size == 1); } + if (ctx->has_glm_projector) { + GGML_ASSERT(batch_size == 1); + ggml_tensor * boi = ctx->vision_model.boi_w; + ggml_backend_tensor_get(boi,vec,0,ggml_nbytes(boi)); + vec = (float*)(vec+ggml_nelements(boi)); //offset for boi + } // build the inference graph ggml_cgraph * gf = clip_image_build_graph(ctx, imgs, ctx->load_image_size, true); @@ -2531,8 +2682,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima // -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316 struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions"); int* positions_data = (int*)malloc(ggml_nbytes(positions)); - int bucket_coords_h[70]; - int bucket_coords_w[70]; + int bucket_coords_h[1024]; + int bucket_coords_w[1024]; for (int i = 0; i < pos_h; i++){ bucket_coords_h[i] = std::floor(70.0*i/pos_h); } @@ -2560,6 +2711,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima else if (ctx->minicpmv_version == 3) { embed_dim = 3584; } + else if (ctx->minicpmv_version == 4) { + embed_dim = 3584; + } auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h)); float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed)); @@ -2622,11 +2776,15 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions)); free(positions_data); - { + if (!ctx->has_glm_projector) { struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches"); + // The patches vector is used to get rows to index into the embeds with; + // we should skip dim 0 only if we have CLS to avoid going out of bounds + // when retrieving the rows. + int patch_offset = ctx->has_class_embedding ? 1 : 0; int* patches_data = (int*)malloc(ggml_nbytes(patches)); for (int i = 0; i < num_patches; i++) { - patches_data[i] = i + 1; + patches_data[i] = i + patch_offset; } ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches)); free(patches_data); @@ -2646,14 +2804,19 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima // copy the embeddings to the location passed by the user ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings)); + if (ctx->has_glm_projector) { + //eoi + ggml_tensor * eoi = ctx->vision_model.eoi_w; + int offset = ggml_nelements(embeddings); + ggml_backend_tensor_get(eoi, vec+offset, 0, ggml_nbytes(eoi)); + } + return true; } bool clip_model_quantize(const char * fname_inp, const char * fname_out, const int itype) { - ggml_type type = GGML_TYPE_Q4_1; - assert(itype < GGML_TYPE_COUNT); - type = static_cast(itype); + ggml_type type = static_cast(itype); auto * ctx_clip = clip_model_load(fname_inp, 2); @@ -2706,8 +2869,8 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i } } - // quantize only 2D tensors - quantize &= (ggml_n_dims(cur) == 2); + // quantize only 2D tensors and bigger than block size + quantize &= (ggml_n_dims(cur) == 2) && cur->ne[0] > ggml_blck_size(type); if (quantize) { new_type = type; @@ -2752,7 +2915,8 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i total_size_org += orig_size; total_size_new += new_size; gguf_set_tensor_type(ctx_out, name.c_str(), new_type); - gguf_set_tensor_data(ctx_out, name.c_str(), new_data, new_size); + GGML_ASSERT(gguf_get_tensor_size(ctx_out, gguf_find_tensor(ctx_out, name.c_str())) == new_size); + gguf_set_tensor_data(ctx_out, name.c_str(), new_data); fout.write((const char *)new_data, new_size); size_t pad = GGML_PAD(new_size, gguf_get_alignment(ctx_out)) - new_size; for (size_t j = 0; j < pad; ++j) { @@ -2802,6 +2966,12 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { else if (ctx->minicpmv_version == 3) { return 3584; } + else if (ctx->minicpmv_version == 4) { + return 3584; + } + } + if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE){ + return ctx->vision_model.mm_model_mlp_3_w->ne[1]; } if (ctx->proj_type == PROJECTOR_TYPE_MERGER) { return ctx->vision_model.mm_1_b->ne[0]; @@ -2818,10 +2988,35 @@ int clip_is_minicpmv(const struct clip_ctx * ctx) { return 0; } +bool clip_is_glm(const struct clip_ctx * ctx) { + return ctx->has_glm_projector; +} bool clip_is_qwen2vl(const struct clip_ctx * ctx) { return ctx->has_qwen2vl_merger; } +// Determine the number of encoder layers to iterate over +int get_deepest_feature_layer(const struct clip_ctx * ctx) { + // Get the index of the second to last layer; this is the + // default for models that have a llava projector + const auto & hparams = ctx->vision_model.hparams; + int n_layer = hparams.n_layer - 1; + int deepest_feature_layer = -1; + + // Handle other projectors; incrementing here indicates that we + // should use the last encoder layer for the vision features. + if (ctx->has_minicpmv_projector || ctx->has_glm_projector || ctx->has_qwen2vl_merger) { + n_layer += 1; + } + + // If we set explicit vision feature layers, only go up to the deepest one + for (const auto & feature_layer : hparams.vision_feature_layer) { + if (feature_layer > deepest_feature_layer) { + deepest_feature_layer = feature_layer; + } + } + return deepest_feature_layer < 0 ? n_layer : deepest_feature_layer; +} bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) { clip_image_f32 clip_img; diff --git a/llama/llama.cpp/examples/llava/clip.h b/llama/llama.cpp/examples/llava/clip.h index 1603edd2..f9f80d7d 100644 --- a/llama/llama.cpp/examples/llava/clip.h +++ b/llama/llama.cpp/examples/llava/clip.h @@ -55,6 +55,7 @@ CLIP_API int32_t clip_hidden_size(const struct clip_ctx * ctx); CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx); CLIP_API const int32_t * clip_image_grid(const struct clip_ctx * ctx); +CLIP_API size_t get_clip_image_grid_size(const struct clip_ctx * ctx); CLIP_API int clip_n_patches (const struct clip_ctx * ctx); CLIP_API int clip_n_patches_by_img (const struct clip_ctx * ctx, struct clip_image_f32 * img); @@ -73,6 +74,9 @@ CLIP_API void clip_image_f32_free(struct clip_image_f32 * img); CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch * batch); CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch); +/** build image from pixels decoded by other libraries instead of stb_image.h for better performance. The memory layout is RGBRGBRGB..., input buffer length must be 3*nx*ny bytes */ +CLIP_API void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, struct clip_image_u8 * img); + CLIP_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img); /** interpret bytes as an image file with length bytes_length, and use the result to populate img */ @@ -89,10 +93,14 @@ CLIP_API bool clip_image_batch_encode(struct clip_ctx * ctx, int n_threads, cons CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out, int itype); CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx); +CLIP_API bool clip_is_glm(const struct clip_ctx * ctx); CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx); +CLIP_API int get_deepest_feature_layer(const struct clip_ctx * ctx); + CLIP_API bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec); + #ifdef __cplusplus } #endif diff --git a/llama/llama.cpp/examples/llava/llava.cpp b/llama/llama.cpp/examples/llava/llava.cpp index 0f0f3f62..f0e484a1 100644 --- a/llama/llama.cpp/examples/llava/llava.cpp +++ b/llama/llama.cpp/examples/llava/llava.cpp @@ -216,7 +216,7 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector return true; } -static clip_image_f32 * only_v2_5_reshape_by_patch(clip_image_f32 * image, int patch_size) { +static clip_image_f32 * reshape_by_patch(clip_image_f32 * image, int patch_size) { int width = image->nx; int height = image->ny; int num_patches = (height / patch_size) * (width / patch_size); @@ -277,13 +277,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]); } else { - int has_minicpmv_projector = clip_is_minicpmv(ctx_clip); - if (has_minicpmv_projector == 2) { - encoded = clip_image_encode(ctx_clip, n_threads, only_v2_5_reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]); - } - else if (has_minicpmv_projector == 3) { - encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[i], image_embd_v[i]); - } + encoded = clip_image_encode(ctx_clip, n_threads, reshape_by_patch(&img_res_v.data[i], patch_size), image_embd_v[i]); } if (!encoded) { @@ -313,6 +307,23 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli load_image_size->height = img->ny; clip_add_load_image_size(ctx_clip, load_image_size); LOG_INF("%s: load_image_size %d %d\n", __func__, load_image_size->width, load_image_size->height); + delete[] img_res_v.data; + img_res_v.size = 0; + img_res_v.data = nullptr; + } + else if (clip_is_glm(ctx_clip)){ + struct clip_image_size * load_image_size = clip_image_size_init(); + load_image_size->width = img_res_v.data[0].nx; + load_image_size->height = img_res_v.data[0].ny; + clip_add_load_image_size(ctx_clip, load_image_size); + + bool encoded = clip_image_encode(ctx_clip, n_threads, &img_res_v.data[0], image_embd); + int pos = int(load_image_size->width/clip_patch_size(ctx_clip)/2); + *n_img_pos = (pos * pos + 2); + if (!encoded){ + LOG_ERR("Unable to encode image \n"); + return false; + } } else if (strcmp(mm_patch_merge_type, "spatial_unpad") != 0) { // flat / default llava-1.5 type embedding @@ -342,9 +353,10 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli LOG_INF("%s: %d segments encoded in %8.2f ms\n", __func__, (int)img_res_v.size, (t_img_enc_batch_us - t_img_enc_start_us) / 1000.0); const int32_t * image_grid = clip_image_grid(ctx_clip); + const size_t num_gridpoints = get_clip_image_grid_size(ctx_clip); std::vector> grid_pinpoints; - for (int i = 0; i < 32 && image_grid[i] != 0; i += 2) { + for (size_t i = 0; i < num_gridpoints; i += 2) { grid_pinpoints.push_back({image_grid[i], image_grid[i+1]}); } @@ -384,7 +396,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip) { // make sure that the correct mmproj was used, i.e., compare apples to apples - int n_llama_embd = llama_n_embd(llama_get_model(ctx_llama)); + int n_llama_embd = llama_model_n_embd(llama_get_model(ctx_llama)); auto n_image_embd = clip_n_mmproj_embd(ctx_clip); if (n_image_embd != n_llama_embd) { LOG_ERR("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_image_embd, n_llama_embd); @@ -394,10 +406,14 @@ bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * } bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out) { - int num_max_patches = 6; + // Granite vision uses up to 10 patches + base patch + int num_max_patches = 11; if (clip_is_minicpmv(ctx_clip)) { num_max_patches = 10; } + if (clip_is_glm(ctx_clip)) { + num_max_patches = 1; + } float * image_embd; if (clip_is_qwen2vl(ctx_clip)) { // qwen2vl don't split image into chunks, so `num_max_patches` is not needed. @@ -457,7 +473,7 @@ struct llava_embd_batch { }; bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) { - int n_embd = llama_n_embd(llama_get_model(ctx_llama)); + int n_embd = llama_model_n_embd(llama_get_model(ctx_llama)); for (int i = 0; i < image_embed->n_image_pos; i += n_batch) { int n_eval = image_embed->n_image_pos - i; diff --git a/llama/llama.cpp/include/llama-cpp.h b/llama/llama.cpp/include/llama-cpp.h index 1500cb2f..8f636817 100644 --- a/llama/llama.cpp/include/llama-cpp.h +++ b/llama/llama.cpp/include/llama-cpp.h @@ -9,7 +9,7 @@ #include "llama.h" struct llama_model_deleter { - void operator()(llama_model * model) { llama_free_model(model); } + void operator()(llama_model * model) { llama_model_free(model); } }; struct llama_context_deleter { @@ -20,11 +20,11 @@ struct llama_sampler_deleter { void operator()(llama_sampler * sampler) { llama_sampler_free(sampler); } }; -struct llama_lora_adapter_deleter { - void operator()(llama_lora_adapter * lora_adapter) { llama_lora_adapter_free(lora_adapter); } +struct llama_adapter_lora_deleter { + void operator()(llama_adapter_lora * adapter) { llama_adapter_lora_free(adapter); } }; typedef std::unique_ptr llama_model_ptr; typedef std::unique_ptr llama_context_ptr; typedef std::unique_ptr llama_sampler_ptr; -typedef std::unique_ptr llama_lora_adapter_ptr; +typedef std::unique_ptr llama_adapter_lora_ptr; diff --git a/llama/llama.cpp/include/llama.h b/llama/llama.cpp/include/llama.h index 9f411960..cc948005 100644 --- a/llama/llama.cpp/include/llama.h +++ b/llama/llama.cpp/include/llama.h @@ -34,7 +34,6 @@ #define LLAMA_DEFAULT_SEED 0xFFFFFFFF -// TODO: use everywhere in the implementation #define LLAMA_TOKEN_NULL -1 #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla' @@ -57,7 +56,7 @@ extern "C" { // TODO: show sample usage // - // struct llama_vocab; // TODO: add in the future + struct llama_vocab; struct llama_model; struct llama_context; struct llama_sampler; @@ -214,7 +213,7 @@ extern "C" { LLAMA_SPLIT_MODE_ROW = 2, // split layers and KV across GPUs, use tensor parallelism if supported }; - // TODO: simplify (https://github.com/ggerganov/llama.cpp/pull/9294#pullrequestreview-2286561979) + // TODO: simplify (https://github.com/ggml-org/llama.cpp/pull/9294#pullrequestreview-2286561979) typedef struct llama_token_data { llama_token id; // token id float logit; // log-odds of the token @@ -290,9 +289,6 @@ extern "C" { // proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices() const float * tensor_split; - // comma separated list of RPC servers to use for offloading - const char * rpc_servers; - // Called with a progress value between 0.0 and 1.0. Pass NULL to disable. // If the provided progress_callback returns true, model loading continues. // If it returns false, model loading is immediately aborted. @@ -312,7 +308,7 @@ extern "C" { }; // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations - // https://github.com/ggerganov/llama.cpp/pull/7544 + // https://github.com/ggml-org/llama.cpp/pull/7544 struct llama_context_params { uint32_t n_ctx; // text context, 0 = from model uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode @@ -325,7 +321,7 @@ extern "C" { enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id enum llama_attention_type attention_type; // attention type to use for embeddings - // ref: https://github.com/ggerganov/llama.cpp/pull/2054 + // ref: https://github.com/ggml-org/llama.cpp/pull/2054 float rope_freq_base; // RoPE base frequency, 0 = from model float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model float yarn_ext_factor; // YaRN extrapolation mix factor, negative = from model @@ -388,11 +384,10 @@ extern "C" { } llama_chat_message; // lora adapter - // TODO: rename to llama_adapter_lora - struct llama_lora_adapter; + struct llama_adapter_lora; // Helpers for getting default parameters - // TODO: update API to start accepting pointers to params structs (https://github.com/ggerganov/llama.cpp/discussions/9172) + // TODO: update API to start accepting pointers to params structs (https://github.com/ggml-org/llama.cpp/discussions/9172) LLAMA_API struct llama_model_params llama_model_default_params(void); LLAMA_API struct llama_context_params llama_context_default_params(void); LLAMA_API struct llama_sampler_chain_params llama_sampler_chain_default_params(void); @@ -403,31 +398,53 @@ extern "C" { // Call once at the start of the program LLAMA_API void llama_backend_init(void); + // Call once at the end of the program - currently only used for MPI + LLAMA_API void llama_backend_free(void); + //optional: LLAMA_API void llama_numa_init(enum ggml_numa_strategy numa); // Optional: an auto threadpool gets created in ggml if not passed explicitly LLAMA_API void llama_attach_threadpool( - struct llama_context * ctx, - ggml_threadpool_t threadpool, - ggml_threadpool_t threadpool_batch); + struct llama_context * ctx, + ggml_threadpool_t threadpool, + ggml_threadpool_t threadpool_batch); + LLAMA_API void llama_detach_threadpool(struct llama_context * ctx); - // Call once at the end of the program - currently only used for MPI - LLAMA_API void llama_backend_free(void); + DEPRECATED(LLAMA_API struct llama_model * llama_load_model_from_file( + const char * path_model, + struct llama_model_params params), + "use llama_model_load_from_file instead"); - LLAMA_API struct llama_model * llama_load_model_from_file( + // Load the model from a file + // If the file is split into multiple parts, the file name must follow this pattern: -%05d-of-%05d.gguf + // If the split file name does not follow this pattern, use llama_model_load_from_splits + LLAMA_API struct llama_model * llama_model_load_from_file( const char * path_model, struct llama_model_params params); - // TODO: rename to llama_model_free - LLAMA_API void llama_free_model(struct llama_model * model); + // Load the model from multiple splits (support custom naming scheme) + // The paths must be in the correct order + LLAMA_API struct llama_model * llama_model_load_from_splits( + const char ** paths, + size_t n_paths, + struct llama_model_params params); - // TODO: rename to llama_init_from_model - LLAMA_API struct llama_context * llama_new_context_with_model( + DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model), + "use llama_model_free instead"); + + LLAMA_API void llama_model_free(struct llama_model * model); + + LLAMA_API struct llama_context * llama_init_from_model( struct llama_model * model, struct llama_context_params params); + DEPRECATED(LLAMA_API struct llama_context * llama_new_context_with_model( + struct llama_model * model, + struct llama_context_params params), + "use llama_init_from_model instead"); + // TODO (jmorganca): this should most likely be passed in as part of a batch // and not set on the context for all batches. LLAMA_API void llama_set_cross_attention(struct llama_context * ctx, bool cross_attn_state); @@ -449,20 +466,31 @@ extern "C" { LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx); - LLAMA_API int32_t llama_n_vocab (const struct llama_model * model); - LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model); - LLAMA_API int32_t llama_n_embd (const struct llama_model * model); - LLAMA_API int32_t llama_n_layer (const struct llama_model * model); - LLAMA_API int32_t llama_n_head (const struct llama_model * model); + DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train instead"); + DEPRECATED(LLAMA_API int32_t llama_n_embd (const struct llama_model * model), "use llama_model_n_embd instead"); + DEPRECATED(LLAMA_API int32_t llama_n_layer (const struct llama_model * model), "use llama_model_n_layer instead"); + DEPRECATED(LLAMA_API int32_t llama_n_head (const struct llama_model * model), "use llama_model_n_head instead"); - LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx); + DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead"); - LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); - LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model); - LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model); + LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx); + LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); + + LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model); + LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model); + + LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model); + LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_head (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model); // Get the model's RoPE frequency scaling factor - LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model); + LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model); + + LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab); + + LLAMA_API int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab); // Functions to access the model's GGUF metadata scalar values // - The functions return the length of the string on success, or -1 on failure @@ -488,6 +516,10 @@ extern "C" { // Returns the total size of all the tensors in the model in bytes LLAMA_API uint64_t llama_model_size(const struct llama_model * model); + // Get the default chat template. Returns nullptr if not available + // If name is NULL, returns the default chat template + LLAMA_API const char * llama_model_chat_template(const struct llama_model * model, const char * name); + // Returns the total number of parameters in the model LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model); @@ -515,34 +547,31 @@ extern "C" { // // Load a LoRA adapter from file - // TODO: rename to llama_adapter_lora_init - LLAMA_API struct llama_lora_adapter * llama_lora_adapter_init( + LLAMA_API struct llama_adapter_lora * llama_adapter_lora_init( struct llama_model * model, const char * path_lora); + // Manually free a LoRA adapter + // Note: loaded adapters will be free when the associated model is deleted + LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter); + + // The following functions operate on a llama_context, hence the naming: llama_verb_... + // Add a loaded LoRA adapter to given context // This will not modify model's weight - // TODO: rename to llama_set_adapter_lora - LLAMA_API int32_t llama_lora_adapter_set( + LLAMA_API int32_t llama_set_adapter_lora( struct llama_context * ctx, - struct llama_lora_adapter * adapter, + struct llama_adapter_lora * adapter, float scale); // Remove a specific LoRA adapter from given context // Return -1 if the adapter is not present in the context - // TODO: rename to llama_rm_adapter_lora - LLAMA_API int32_t llama_lora_adapter_remove( + LLAMA_API int32_t llama_rm_adapter_lora( struct llama_context * ctx, - struct llama_lora_adapter * adapter); + struct llama_adapter_lora * adapter); // Remove all LoRA adapters from given context - // TODO: rename to llama_clear_adapter_lora - LLAMA_API void llama_lora_adapter_clear(struct llama_context * ctx); - - // Manually free a LoRA adapter - // Note: loaded adapters will be free when the associated model is deleted - // TODO: rename to llama_adapter_lora_free - LLAMA_API void llama_lora_adapter_free(struct llama_lora_adapter * adapter); + LLAMA_API void llama_clear_adapter_lora(struct llama_context * ctx); // Apply a loaded control vector to a llama_context, or if data is NULL, clear // the currently loaded vector. @@ -550,9 +579,8 @@ extern "C" { // to an n_embd x n_layers buffer starting from layer 1. // il_start and il_end are the layer range the vector should apply to (both inclusive) // See llama_control_vector_load in common to load a control vector. - // TODO: rename to llama_adapter_cvec_apply - LLAMA_API int32_t llama_control_vector_apply( - struct llama_context * lctx, + LLAMA_API int32_t llama_apply_adapter_cvec( + struct llama_context * ctx, const float * data, size_t len, int32_t n_embd, @@ -908,41 +936,60 @@ extern "C" { // Vocab // - LLAMA_API const char * llama_token_get_text(const struct llama_model * model, llama_token token); + LLAMA_API const char * llama_vocab_get_text(const struct llama_vocab * vocab, llama_token token); - LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token); + LLAMA_API float llama_vocab_get_score(const struct llama_vocab * vocab, llama_token token); - LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token); + LLAMA_API enum llama_token_attr llama_vocab_get_attr(const struct llama_vocab * vocab, llama_token token); // Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.) - LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token); + LLAMA_API bool llama_vocab_is_eog(const struct llama_vocab * vocab, llama_token token); // Identify if Token Id is a control token or a render-able token - LLAMA_API bool llama_token_is_control(const struct llama_model * model, llama_token token); + LLAMA_API bool llama_vocab_is_control(const struct llama_vocab * vocab, llama_token token); // Special tokens - LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence - LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence - LLAMA_API llama_token llama_token_eot(const struct llama_model * model); // end-of-turn - LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification - LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator - LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line - LLAMA_API llama_token llama_token_pad(const struct llama_model * model); // padding + LLAMA_API llama_token llama_vocab_bos(const struct llama_vocab * vocab); // beginning-of-sentence + LLAMA_API llama_token llama_vocab_eos(const struct llama_vocab * vocab); // end-of-sentence + LLAMA_API llama_token llama_vocab_eot(const struct llama_vocab * vocab); // end-of-turn + LLAMA_API llama_token llama_vocab_sep(const struct llama_vocab * vocab); // sentence separator + LLAMA_API llama_token llama_vocab_nl (const struct llama_vocab * vocab); // next-line + LLAMA_API llama_token llama_vocab_pad(const struct llama_vocab * vocab); // padding - LLAMA_API bool llama_add_bos_token(const struct llama_model * model); - LLAMA_API bool llama_add_eos_token(const struct llama_model * model); + LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab); + LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab); - // infill tokens - DEPRECATED(LLAMA_API llama_token llama_token_prefix(const struct llama_model * model), "use llama_token_fim_pre instead"); - DEPRECATED(LLAMA_API llama_token llama_token_middle(const struct llama_model * model), "use llama_token_fim_mid instead"); - DEPRECATED(LLAMA_API llama_token llama_token_suffix(const struct llama_model * model), "use llama_token_fim_suf instead"); + LLAMA_API llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab); + LLAMA_API llama_token llama_vocab_fim_suf(const struct llama_vocab * vocab); + LLAMA_API llama_token llama_vocab_fim_mid(const struct llama_vocab * vocab); + LLAMA_API llama_token llama_vocab_fim_pad(const struct llama_vocab * vocab); + LLAMA_API llama_token llama_vocab_fim_rep(const struct llama_vocab * vocab); + LLAMA_API llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab); - LLAMA_API llama_token llama_token_fim_pre(const struct llama_model * model); - LLAMA_API llama_token llama_token_fim_suf(const struct llama_model * model); - LLAMA_API llama_token llama_token_fim_mid(const struct llama_model * model); - LLAMA_API llama_token llama_token_fim_pad(const struct llama_model * model); - LLAMA_API llama_token llama_token_fim_rep(const struct llama_model * model); - LLAMA_API llama_token llama_token_fim_sep(const struct llama_model * model); + DEPRECATED(LLAMA_API const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_text instead"); + DEPRECATED(LLAMA_API float llama_token_get_score(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_score instead"); + DEPRECATED(LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_attr instead"); + DEPRECATED(LLAMA_API bool llama_token_is_eog(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_is_eog instead"); + DEPRECATED(LLAMA_API bool llama_token_is_control(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_is_control instead"); + DEPRECATED(LLAMA_API llama_token llama_token_bos(const struct llama_vocab * vocab), "use llama_vocab_bos instead"); + DEPRECATED(LLAMA_API llama_token llama_token_eos(const struct llama_vocab * vocab), "use llama_vocab_eos instead"); + DEPRECATED(LLAMA_API llama_token llama_token_eot(const struct llama_vocab * vocab), "use llama_vocab_eot instead"); + DEPRECATED(LLAMA_API llama_token llama_token_cls(const struct llama_vocab * vocab), "use llama_vocab_cls instead"); + DEPRECATED(LLAMA_API llama_token llama_token_sep(const struct llama_vocab * vocab), "use llama_vocab_sep instead"); + DEPRECATED(LLAMA_API llama_token llama_token_nl (const struct llama_vocab * vocab), "use llama_vocab_nl instead"); + DEPRECATED(LLAMA_API llama_token llama_token_pad(const struct llama_vocab * vocab), "use llama_vocab_pad instead"); + DEPRECATED(LLAMA_API bool llama_add_bos_token(const struct llama_vocab * vocab), "use llama_vocab_get_add_bos instead"); + DEPRECATED(LLAMA_API bool llama_add_eos_token(const struct llama_vocab * vocab), "use llama_vocab_get_add_eos instead"); + DEPRECATED(LLAMA_API llama_token llama_token_fim_pre(const struct llama_vocab * vocab), "use llama_vocab_fim_pre instead"); + DEPRECATED(LLAMA_API llama_token llama_token_fim_suf(const struct llama_vocab * vocab), "use llama_vocab_fim_suf instead"); + DEPRECATED(LLAMA_API llama_token llama_token_fim_mid(const struct llama_vocab * vocab), "use llama_vocab_fim_mid instead"); + DEPRECATED(LLAMA_API llama_token llama_token_fim_pad(const struct llama_vocab * vocab), "use llama_vocab_fim_pad instead"); + DEPRECATED(LLAMA_API llama_token llama_token_fim_rep(const struct llama_vocab * vocab), "use llama_vocab_fim_rep instead"); + DEPRECATED(LLAMA_API llama_token llama_token_fim_sep(const struct llama_vocab * vocab), "use llama_vocab_fim_sep instead"); + + // CLS is equivalent to BOS + DEPRECATED(LLAMA_API llama_token llama_vocab_cls(const struct llama_vocab * vocab), // classification + "use llama_vocab_bos instead"); // // Tokenization @@ -958,7 +1005,7 @@ extern "C" { /// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated /// as plaintext. Does not insert a leading space. LLAMA_API int32_t llama_tokenize( - const struct llama_model * model, + const struct llama_vocab * vocab, const char * text, int32_t text_len, llama_token * tokens, @@ -972,7 +1019,7 @@ extern "C" { // User can skip up to 'lstrip' leading spaces before copying (useful when encoding/decoding multiple tokens with 'add_space_prefix') // @param special If true, special tokens are rendered in the output. LLAMA_API int32_t llama_token_to_piece( - const struct llama_model * model, + const struct llama_vocab * vocab, llama_token token, char * buf, int32_t length, @@ -986,7 +1033,7 @@ extern "C" { /// @param remove_special Allow to remove BOS and EOS tokens if model is configured to do so. /// @param unparse_special If true, special tokens are rendered in the output. LLAMA_API int32_t llama_detokenize( - const struct llama_model * model, + const struct llama_vocab * vocab, const llama_token * tokens, int32_t n_tokens, char * text, @@ -1000,7 +1047,7 @@ extern "C" { /// Apply chat template. Inspired by hf apply_chat_template() on python. /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model" - /// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template + /// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggml-org/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template /// @param tmpl A Jinja template to use for this chat. If this is nullptr, the model’s default chat template will be used instead. /// @param chat Pointer to a list of multiple llama_chat_message /// @param n_msg Number of llama_chat_message in this chat @@ -1009,7 +1056,6 @@ extern "C" { /// @param length The size of the allocated buffer /// @return The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template. LLAMA_API int32_t llama_chat_apply_template( - const struct llama_model * model, const char * tmpl, const struct llama_chat_message * chat, size_t n_msg, @@ -1057,7 +1103,6 @@ extern "C" { // llama_sampler_free(smpl); // // TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU). - // TODO: in the future, the entire sampling API that uses llama_model should start using llama_vocab // typedef void * llama_sampler_context_t; @@ -1076,11 +1121,12 @@ extern "C" { }; struct llama_sampler { - struct llama_sampler_i * iface; - llama_sampler_context_t ctx; + const struct llama_sampler_i * iface; + llama_sampler_context_t ctx; }; // mirror of llama_sampler_i: + LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_sampler_i * iface, llama_sampler_context_t ctx); LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl); LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token); LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p); @@ -1110,7 +1156,7 @@ extern "C" { /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first. DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void), - "will be removed in the future (see https://github.com/ggerganov/llama.cpp/pull/9896#discussion_r1800920915)"); + "will be removed in the future (see https://github.com/ggml-org/llama.cpp/pull/9896#discussion_r1800920915)"); /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k); @@ -1118,7 +1164,7 @@ extern "C" { /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 LLAMA_API struct llama_sampler * llama_sampler_init_top_p (float p, size_t min_keep); - /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 + /// @details Minimum P sampling as described in https://github.com/ggml-org/llama.cpp/pull/3841 LLAMA_API struct llama_sampler * llama_sampler_init_min_p (float p, size_t min_keep); /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. @@ -1133,6 +1179,9 @@ extern "C" { /// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335 LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed); + /// @details Top n sigma sampling as described in academic paper "Top-nσ: Not All Logits Are You Need" https://arxiv.org/pdf/2411.07641 + LLAMA_API struct llama_sampler * llama_sampler_init_top_n_sigma(float n); + /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. @@ -1157,10 +1206,22 @@ extern "C" { float eta); LLAMA_API struct llama_sampler * llama_sampler_init_grammar( - const struct llama_model * model, + const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root); + /// @details Lazy grammar sampler, introduced in https://github.com/ggml-org/llama.cpp/pull/9639 + /// @param trigger_words A list of words that will trigger the grammar sampler. This may be updated to a loose regex syntax (w/ ^) in a near future. + /// @param trigger_tokens A list of tokens that will trigger the grammar sampler. + LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + const char ** trigger_words, + size_t num_trigger_words, + const llama_token * trigger_tokens, + size_t num_trigger_tokens); + /// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first. LLAMA_API struct llama_sampler * llama_sampler_init_penalties( int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size) @@ -1169,8 +1230,9 @@ extern "C" { float penalty_present); // 0.0 = disabled /// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982 - LLAMA_API struct llama_sampler * llama_sampler_init_dry( - const struct llama_model * model, + LLAMA_API struct llama_sampler * llama_sampler_init_dry( + const struct llama_vocab * vocab, + int32_t n_ctx_train, float dry_multiplier, float dry_base, int32_t dry_allowed_length, @@ -1204,7 +1266,7 @@ extern "C" { // 3. discard non-EOG tokens with low prob // 4. if no tokens are left -> pick EOT // - LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model); + LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab); // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl); diff --git a/llama/llama.cpp/src/llama-adapter.cpp b/llama/llama.cpp/src/llama-adapter.cpp index 9fd7edea..8a080046 100644 --- a/llama/llama.cpp/src/llama-adapter.cpp +++ b/llama/llama.cpp/src/llama-adapter.cpp @@ -1,5 +1,7 @@ #include "llama-adapter.h" +#include "llama-impl.h" +#include "llama-mmap.h" #include "llama-model.h" #include @@ -9,7 +11,7 @@ // vec -struct ggml_tensor * llama_control_vector::tensor_for(int il) const { +struct ggml_tensor * llama_adapter_cvec::tensor_for(int il) const { if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) { return nullptr; } @@ -17,7 +19,7 @@ struct ggml_tensor * llama_control_vector::tensor_for(int il) const { return tensors[il]; } -struct ggml_tensor * llama_control_vector::apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const { +struct ggml_tensor * llama_adapter_cvec::apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const { ggml_tensor * layer_dir = tensor_for(il); if (layer_dir != nullptr) { cur = ggml_add(ctx, cur, layer_dir); @@ -26,12 +28,12 @@ struct ggml_tensor * llama_control_vector::apply_to(struct ggml_context * ctx, s return cur; } -static bool llama_control_vector_init(struct llama_control_vector & cvec, const llama_model & model) { +bool llama_adapter_cvec::init(const llama_model & model) { const auto & hparams = model.hparams; - GGML_ASSERT(cvec.tensors.empty()); - GGML_ASSERT(cvec.ctxs.empty()); - GGML_ASSERT(cvec.bufs.empty()); + GGML_ASSERT(tensors.empty()); + GGML_ASSERT(ctxs.empty()); + GGML_ASSERT(bufs.empty()); // create a context for each buffer type std::map ctx_map; @@ -50,7 +52,7 @@ static bool llama_control_vector_init(struct llama_control_vector & cvec, const } ctx_map[buft] = ctx; - cvec.ctxs.emplace_back(ctx); + ctxs.emplace_back(ctx); return ctx; } @@ -59,21 +61,21 @@ static bool llama_control_vector_init(struct llama_control_vector & cvec, const }; // make tensors - cvec.tensors.reserve(hparams.n_layer); - cvec.tensors.push_back(nullptr); // there's never a tensor for layer 0 + tensors.reserve(hparams.n_layer); + tensors.push_back(nullptr); // there's never a tensor for layer 0 for (size_t il = 1; il < hparams.n_layer; il++) { - ggml_backend_buffer_type_t buft = llama_model_select_buft(model, il); + ggml_backend_buffer_type_t buft = model.select_buft(il); ggml_context * ctx = ctx_for_buft(buft); if (!ctx) { LLAMA_LOG_ERROR("%s: failed to allocate context for control vector\n", __func__); return false; } ggml_tensor * tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd); - cvec.tensors.push_back(tensor); + tensors.push_back(tensor); } // allocate tensors / buffers and zero - cvec.bufs.reserve(ctx_map.size()); + bufs.reserve(ctx_map.size()); for (auto it : ctx_map) { ggml_backend_buffer_type_t buft = it.first; ggml_context * ctx = it.second; @@ -83,14 +85,13 @@ static bool llama_control_vector_init(struct llama_control_vector & cvec, const return false; } ggml_backend_buffer_clear(buf, 0); - cvec.bufs.emplace_back(buf); + bufs.emplace_back(buf); } return true; } -int32_t llama_control_vector_apply( - struct llama_control_vector & cvec, +int32_t llama_adapter_cvec::apply( const llama_model & model, const float * data, size_t len, @@ -101,8 +102,8 @@ int32_t llama_control_vector_apply( if (data == nullptr) { // disable the current control vector (but leave allocated for later) - cvec.layer_start = -1; - cvec.layer_end = -1; + layer_start = -1; + layer_end = -1; return 0; } @@ -111,21 +112,21 @@ int32_t llama_control_vector_apply( return 1; } - if (cvec.tensors.empty()) { - if (!llama_control_vector_init(cvec, model)) { + if (tensors.empty()) { + if (!init(model)) { return 1; } } - cvec.layer_start = il_start; - cvec.layer_end = il_end; + layer_start = il_start; + layer_end = il_end; for (size_t il = 1; il < hparams.n_layer; il++) { - assert(cvec.tensors[il] != nullptr); + assert(tensors[il] != nullptr); const size_t off = n_embd * (il - 1); // buffer doesn't have data for layer 0, since it's never present if (off + n_embd <= len) { - ggml_backend_tensor_set(cvec.tensors[il], data + off, 0, n_embd * ggml_element_size(cvec.tensors[il])); + ggml_backend_tensor_set(tensors[il], data + off, 0, n_embd * ggml_element_size(tensors[il])); } } @@ -134,7 +135,7 @@ int32_t llama_control_vector_apply( // lora -llama_lora_weight * llama_lora_adapter::get_weight(struct ggml_tensor * w) { +llama_adapter_lora_weight * llama_adapter_lora::get_weight(struct ggml_tensor * w) { const std::string name(w->name); const auto pos = ab_map.find(name); @@ -145,11 +146,7 @@ llama_lora_weight * llama_lora_adapter::get_weight(struct ggml_tensor * w) { return nullptr; } -void llama_lora_adapter_free(struct llama_lora_adapter * adapter) { - delete adapter; -} - -static void llama_lora_adapter_init_impl(struct llama_model & model, const char * path_lora, struct llama_lora_adapter & adapter) { +static void llama_adapter_lora_init_impl(struct llama_model & model, const char * path_lora, struct llama_adapter_lora & adapter) { LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora); ggml_context * ctx_init; @@ -221,7 +218,7 @@ static void llama_lora_adapter_init_impl(struct llama_model & model, const char }; // bundle lora_a and lora_b into pairs - std::map ab_map; + std::map ab_map; auto str_endswith = [](const std::string & str, const std::string & suffix) { return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0; }; @@ -231,17 +228,21 @@ static void llama_lora_adapter_init_impl(struct llama_model & model, const char if (str_endswith(name, ".lora_a")) { replace_all(name, ".lora_a", ""); if (ab_map.find(name) == ab_map.end()) { - ab_map[name] = llama_lora_weight(cur, nullptr); + ab_map[name] = llama_adapter_lora_weight(cur, nullptr); } else { ab_map[name].a = cur; } } else if (str_endswith(name, ".lora_b")) { replace_all(name, ".lora_b", ""); if (ab_map.find(name) == ab_map.end()) { - ab_map[name] = llama_lora_weight(nullptr, cur); + ab_map[name] = llama_adapter_lora_weight(nullptr, cur); } else { ab_map[name].b = cur; } + } else if (str_endswith(name, "_norm.weight")) { + // TODO: add support for norm vector + // for now, we don't really care because most adapters still work fine without it + continue; } else { throw std::runtime_error("LoRA tensor '" + name + "' has unexpected suffix"); } @@ -250,25 +251,33 @@ static void llama_lora_adapter_init_impl(struct llama_model & model, const char // add tensors for (auto & it : ab_map) { const std::string & name = it.first; - llama_lora_weight & w = it.second; + llama_adapter_lora_weight & w = it.second; + bool is_token_embd = str_endswith(name, "token_embd.weight"); if (!w.a || !w.b) { throw std::runtime_error("LoRA tensor pair for '" + name + "' is missing one component"); } // device buft and device ctx - auto * model_tensor = llama_model_get_tensor(model, name.c_str()); + const auto * model_tensor = model.get_tensor(name.c_str()); if (!model_tensor) { - throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model"); + throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model (hint: maybe wrong base model?)"); } struct ggml_context * dev_ctx = ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer)); // validate tensor shape - if (model_tensor->ne[0] != w.a->ne[0] || model_tensor->ne[1] != w.b->ne[1]) { - throw std::runtime_error("tensor '" + name + "' has incorrect shape"); - } - if (w.a->ne[1] != w.b->ne[0]) { - throw std::runtime_error("lora_a tensor is not transposed (hint: adapter from \"finetune\" example is no longer supported)"); + if (is_token_embd) { + // expect B to be non-transposed, A and B are flipped; see llm_build_inp_embd() + if (model_tensor->ne[0] != w.b->ne[1] || model_tensor->ne[1] != w.a->ne[1]) { + throw std::runtime_error("tensor '" + name + "' has incorrect shape (hint: maybe wrong base model?)"); + } + } else { + if (model_tensor->ne[0] != w.a->ne[0] || model_tensor->ne[1] != w.b->ne[1]) { + throw std::runtime_error("tensor '" + name + "' has incorrect shape (hint: maybe wrong base model?)"); + } + if (w.a->ne[1] != w.b->ne[0]) { + throw std::runtime_error("lora_a tensor is not transposed (hint: adapter from \"finetune\" example is no longer supported)"); + } } // save tensor to adapter @@ -276,7 +285,7 @@ static void llama_lora_adapter_init_impl(struct llama_model & model, const char struct ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b); ggml_set_name(tensor_a, w.a->name); ggml_set_name(tensor_b, w.b->name); - adapter.ab_map[name] = llama_lora_weight(tensor_a, tensor_b); + adapter.ab_map[name] = llama_adapter_lora_weight(tensor_a, tensor_b); } // allocate tensors / buffers and zero @@ -318,11 +327,11 @@ static void llama_lora_adapter_init_impl(struct llama_model & model, const char LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2); } -struct llama_lora_adapter * llama_lora_adapter_init(struct llama_model * model, const char * path_lora) { - struct llama_lora_adapter * adapter = new llama_lora_adapter(); +struct llama_adapter_lora * llama_adapter_lora_init(struct llama_model * model, const char * path_lora) { + struct llama_adapter_lora * adapter = new llama_adapter_lora(); try { - llama_lora_adapter_init_impl(*model, path_lora, *adapter); + llama_adapter_lora_init_impl(*model, path_lora, *adapter); return adapter; } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what()); @@ -332,3 +341,7 @@ struct llama_lora_adapter * llama_lora_adapter_init(struct llama_model * model, return nullptr; } + +void llama_adapter_lora_free(struct llama_adapter_lora * adapter) { + delete adapter; +} diff --git a/llama/llama.cpp/src/llama-adapter.h b/llama/llama.cpp/src/llama-adapter.h index 5f1870cc..603fa08f 100644 --- a/llama/llama.cpp/src/llama-adapter.h +++ b/llama/llama.cpp/src/llama-adapter.h @@ -1,66 +1,74 @@ #pragma once -#include "llama-impl.h" -#include "llama-hparams.h" +#include "llama.h" #include "ggml-cpp.h" +#include #include #include +// TODO: pimpl + // // llama_adapter_cvec // -// TODO: rename to llama_adapter_cvec -struct llama_control_vector { - std::vector ctxs; - std::vector bufs; +struct llama_adapter_cvec { + struct ggml_tensor * tensor_for(int il) const; - std::vector tensors; // per layer + struct ggml_tensor * apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const; + + int32_t apply( + const llama_model & model, + const float * data, + size_t len, + int32_t n_embd, + int32_t il_start, + int32_t il_end); + +private: + bool init(const llama_model & model); int32_t layer_start = -1; int32_t layer_end = -1; - struct ggml_tensor * tensor_for(int il) const; + std::vector ctxs; + std::vector bufs; - struct ggml_tensor * apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const; + std::vector tensors; // per layer }; -int32_t llama_control_vector_apply( - struct llama_control_vector & cvec, - const llama_model & model, - const float * data, - size_t len, - int32_t n_embd, - int32_t il_start, - int32_t il_end); - // // llama_adapter_lora // -// TODO: rename to llama_adapter_lora_weight -struct llama_lora_weight { +struct llama_adapter_lora_weight { struct ggml_tensor * a = nullptr; struct ggml_tensor * b = nullptr; - llama_lora_weight() = default; - llama_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b) : a(a), b(b) {} + // get actual scale based on rank and alpha + float get_scale(float alpha, float adapter_scale) const { + const float rank = (float) b->ne[0]; + const float scale = alpha ? adapter_scale * alpha / rank : adapter_scale; + return scale; + } + + llama_adapter_lora_weight() = default; + llama_adapter_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b) : a(a), b(b) {} }; -// TODO: rename to llama_adapter_lora -struct llama_lora_adapter { +struct llama_adapter_lora { // map tensor name to lora_a_b - std::unordered_map ab_map; + std::unordered_map ab_map; std::vector ctxs; std::vector bufs; float alpha; - llama_lora_adapter() = default; - ~llama_lora_adapter() = default; + llama_adapter_lora() = default; + ~llama_adapter_lora() = default; - llama_lora_weight * get_weight(struct ggml_tensor * w); + llama_adapter_lora_weight * get_weight(struct ggml_tensor * w); }; diff --git a/llama/llama.cpp/src/llama-arch.cpp b/llama/llama.cpp/src/llama-arch.cpp index b35aeb31..b6f20286 100644 --- a/llama/llama.cpp/src/llama-arch.cpp +++ b/llama/llama.cpp/src/llama-arch.cpp @@ -28,6 +28,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_QWEN2VL, "qwen2vl" }, { LLM_ARCH_PHI2, "phi2" }, { LLM_ARCH_PHI3, "phi3" }, + { LLM_ARCH_PHIMOE, "phimoe" }, { LLM_ARCH_PLAMO, "plamo" }, { LLM_ARCH_CODESHELL, "codeshell" }, { LLM_ARCH_ORION, "orion" }, @@ -57,6 +58,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_NEMOTRON, "nemotron" }, { LLM_ARCH_EXAONE, "exaone" }, { LLM_ARCH_RWKV6, "rwkv6" }, + { LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" }, { LLM_ARCH_GRANITE, "granite" }, { LLM_ARCH_GRANITE_MOE, "granitemoe" }, { LLM_ARCH_CHAMELEON, "chameleon" }, @@ -107,25 +109,26 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_TIME_DECAY_EXTRA_DIM, "%s.time_decay_extra_dim" }, { LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" }, { LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" }, + { LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" }, - { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, - { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, - { LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" }, - { LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" }, - { LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length" }, - { LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" }, - { LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" }, - { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" }, - { LLM_KV_ATTENTION_GROUPNORM_EPS, "%s.attention.group_norm_epsilon" }, - { LLM_KV_ATTENTION_GROUPNORM_GROUPS, "%s.attention.group_norm_groups" }, - { LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" }, - { LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" }, - { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" }, - { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" }, - { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" }, - { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, - { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, "%s.attention.block_skip_connection" }, - { LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, "%s.attention.cross_attention_layers" }, + { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, + { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, + { LLM_KV_ATTENTION_MAX_ALIBI_BIAS, "%s.attention.max_alibi_bias" }, + { LLM_KV_ATTENTION_CLAMP_KQV, "%s.attention.clamp_kqv" }, + { LLM_KV_ATTENTION_KEY_LENGTH, "%s.attention.key_length" }, + { LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" }, + { LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" }, + { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" }, + { LLM_KV_ATTENTION_GROUPNORM_EPS, "%s.attention.group_norm_epsilon" }, + { LLM_KV_ATTENTION_GROUPNORM_GROUPS, "%s.attention.group_norm_groups" }, + { LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" }, + { LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" }, + { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" }, + { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" }, + { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" }, + { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, + { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, "%s.attention.block_skip_connection" }, + { LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, "%s.attention.cross_attention_layers" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, @@ -179,6 +182,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" }, { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, + { LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" }, + { LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, "tokenizer.chat_template.%s" }, { LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" }, { LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" }, { LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" }, @@ -622,6 +627,27 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_PHIMOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" }, + { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, { LLM_ARCH_PLAMO, { @@ -1036,6 +1062,9 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_OUTPUT, "output" }, { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, @@ -1182,6 +1211,7 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_TIME_MIX_LERP_V, "blk.%d.time_mix_lerp_v" }, { LLM_TENSOR_TIME_MIX_LERP_R, "blk.%d.time_mix_lerp_r" }, { LLM_TENSOR_TIME_MIX_LERP_G, "blk.%d.time_mix_lerp_g" }, + { LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" }, { LLM_TENSOR_TIME_MIX_FIRST, "blk.%d.time_mix_first" }, { LLM_TENSOR_TIME_MIX_DECAY, "blk.%d.time_mix_decay" }, { LLM_TENSOR_TIME_MIX_DECAY_W1, "blk.%d.time_mix_decay_w1" }, @@ -1199,6 +1229,32 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "blk.%d.channel_mix_receptance" }, }, }, + { + LLM_ARCH_RWKV6QWEN2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" }, + { LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" }, + { LLM_TENSOR_TIME_MIX_LERP_X, "blk.%d.time_mix_lerp_x" }, + { LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" }, + { LLM_TENSOR_TIME_MIX_FIRST, "blk.%d.time_mix_first" }, + { LLM_TENSOR_TIME_MIX_DECAY, "blk.%d.time_mix_decay" }, + { LLM_TENSOR_TIME_MIX_DECAY_W1, "blk.%d.time_mix_decay_w1" }, + { LLM_TENSOR_TIME_MIX_DECAY_W2, "blk.%d.time_mix_decay_w2" }, + { LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" }, + { LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" }, + { LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" }, + { LLM_TENSOR_TIME_MIX_GATE, "blk.%d.time_mix_gate" }, + { LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_GRANITE, { @@ -1253,6 +1309,24 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, }, }, + { + LLM_ARCH_SOLAR, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_BSKCN_TV, "bskcn_tv" }, + }, + }, { LLM_ARCH_WAVTOKENIZER_DEC, { @@ -1278,24 +1352,6 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" }, }, }, - { - LLM_ARCH_SOLAR, - { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_BSKCN_TV, "bskcn_tv" }, - }, - }, { LLM_ARCH_UNKNOWN, { @@ -1399,6 +1455,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_TIME_MIX_LERP_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_LERP_G, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + {LLM_TENSOR_TIME_MIX_LERP_FUSED, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_DECAY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}}, {LLM_TENSOR_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -1455,10 +1512,11 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, }; -LLM_KV::LLM_KV(llm_arch arch) : arch(arch) {} +LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} std::string LLM_KV::operator()(llm_kv kv) const { - return ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch)); + return suffix ? ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch), suffix) + : ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch)); } std::string LLM_TN_IMPL::str() const { diff --git a/llama/llama.cpp/src/llama-arch.h b/llama/llama.cpp/src/llama-arch.h index e8235ae0..ec742224 100644 --- a/llama/llama.cpp/src/llama-arch.h +++ b/llama/llama.cpp/src/llama-arch.h @@ -32,6 +32,7 @@ enum llm_arch { LLM_ARCH_QWEN2VL, LLM_ARCH_PHI2, LLM_ARCH_PHI3, + LLM_ARCH_PHIMOE, LLM_ARCH_PLAMO, LLM_ARCH_CODESHELL, LLM_ARCH_ORION, @@ -61,6 +62,7 @@ enum llm_arch { LLM_ARCH_NEMOTRON, LLM_ARCH_EXAONE, LLM_ARCH_RWKV6, + LLM_ARCH_RWKV6QWEN2, LLM_ARCH_GRANITE, LLM_ARCH_GRANITE_MOE, LLM_ARCH_CHAMELEON, @@ -111,6 +113,7 @@ enum llm_kv { LLM_KV_TIME_DECAY_EXTRA_DIM, LLM_KV_RESIDUAL_SCALE, LLM_KV_EMBEDDING_SCALE, + LLM_KV_TOKEN_SHIFT_COUNT, LLM_KV_ATTENTION_HEAD_COUNT, LLM_KV_ATTENTION_HEAD_COUNT_KV, @@ -177,6 +180,8 @@ enum llm_kv { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, LLM_KV_TOKENIZER_HF_JSON, LLM_KV_TOKENIZER_RWKV, + LLM_KV_TOKENIZER_CHAT_TEMPLATE, + LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, LLM_KV_TOKENIZER_FIM_PRE_ID, LLM_KV_TOKENIZER_FIM_SUF_ID, LLM_KV_TOKENIZER_FIM_MID_ID, @@ -256,6 +261,7 @@ enum llm_tensor { LLM_TENSOR_TIME_MIX_LERP_V, LLM_TENSOR_TIME_MIX_LERP_R, LLM_TENSOR_TIME_MIX_LERP_G, + LLM_TENSOR_TIME_MIX_LERP_FUSED, LLM_TENSOR_TIME_MIX_FIRST, LLM_TENSOR_TIME_MIX_DECAY, LLM_TENSOR_TIME_MIX_DECAY_W1, @@ -343,9 +349,10 @@ enum llm_tensor_layer { }; struct LLM_KV { - LLM_KV(llm_arch arch); + LLM_KV(llm_arch arch, const char * suffix = nullptr); llm_arch arch; + const char * suffix; std::string operator()(llm_kv kv) const; }; diff --git a/llama/llama.cpp/src/llama-chat.cpp b/llama/llama.cpp/src/llama-chat.cpp index 44670d3d..028a6479 100644 --- a/llama/llama.cpp/src/llama-chat.cpp +++ b/llama/llama.cpp/src/llama-chat.cpp @@ -35,6 +35,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN }, { "mistral-v7", LLM_CHAT_TEMPLATE_MISTRAL_V7 }, { "phi3", LLM_CHAT_TEMPLATE_PHI_3 }, + { "phi4", LLM_CHAT_TEMPLATE_PHI_4 }, { "falcon3", LLM_CHAT_TEMPLATE_FALCON_3 }, { "zephyr", LLM_CHAT_TEMPLATE_ZEPHYR }, { "monarch", LLM_CHAT_TEMPLATE_MONARCH }, @@ -50,6 +51,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "llama3", LLM_CHAT_TEMPLATE_LLAMA_3 }, { "chatglm3", LLM_CHAT_TEMPLATE_CHATGML_3 }, { "chatglm4", LLM_CHAT_TEMPLATE_CHATGML_4 }, + { "glmedge", LLM_CHAT_TEMPLATE_GLMEDGE }, { "minicpm", LLM_CHAT_TEMPLATE_MINICPM }, { "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 }, { "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD }, @@ -73,7 +75,9 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return tmpl.find(haystack) != std::string::npos; }; if (tmpl_contains("<|im_start|>")) { - return LLM_CHAT_TEMPLATE_CHATML; + return tmpl_contains("<|im_sep|>") + ? LLM_CHAT_TEMPLATE_PHI_4 + : LLM_CHAT_TEMPLATE_CHATML; } else if (tmpl.find("mistral") == 0 || tmpl_contains("[INST]")) { if (tmpl_contains("[SYSTEM_PROMPT]")) { return LLM_CHAT_TEMPLATE_MISTRAL_V7; @@ -112,7 +116,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>")) { return LLM_CHAT_TEMPLATE_PHI_3; } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) { - return LLM_CHAT_TEMPLATE_FALCON_3; + return tmpl_contains("") ? LLM_CHAT_TEMPLATE_FALCON_3 : LLM_CHAT_TEMPLATE_GLMEDGE; } else if (tmpl_contains("<|user|>") && tmpl_contains("<|endoftext|>")) { return LLM_CHAT_TEMPLATE_ZEPHYR; } else if (tmpl_contains("bos_token + message['role']")) { @@ -149,7 +153,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_MINICPM; } else if (tmpl_contains("'Assistant: ' + message['content'] + eos_token")) { return LLM_CHAT_TEMPLATE_DEEPSEEK_2; - } else if (tmpl_contains(LU8("'<|Assistant|>' + message['content'] + '<|end▁of▁sentence|>'"))) { + } else if (tmpl_contains(LU8("<|Assistant|>")) && tmpl_contains(LU8("<|User|>")) && tmpl_contains(LU8("<|end▁of▁sentence|>"))) { return LLM_CHAT_TEMPLATE_DEEPSEEK_3; } else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) { // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb @@ -269,6 +273,14 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "<|assistant|>\n"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_PHI_4) { + // chatml template + for (auto message : chat) { + ss << "<|im_start|>" << message->role << "<|im_sep|>" << message->content << "<|im_end|>"; + } + if (add_ass) { + ss << "<|im_start|>assistant<|im_sep|>"; + } } else if (tmpl == LLM_CHAT_TEMPLATE_FALCON_3) { // Falcon 3 for (auto message : chat) { @@ -429,6 +441,14 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "<|assistant|>"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_GLMEDGE) { + for (auto message : chat) { + std::string role(message->role); + ss << "<|" << role << "|>" << "\n" << message->content; + } + if (add_ass) { + ss << "<|assistant|>"; + } } else if (tmpl == LLM_CHAT_TEMPLATE_MINICPM) { // MiniCPM-3B-OpenHermes-2.5-v2-GGUF for (auto message : chat) { diff --git a/llama/llama.cpp/src/llama-chat.h b/llama/llama.cpp/src/llama-chat.h index b8e94d9e..2f6a0e3e 100644 --- a/llama/llama.cpp/src/llama-chat.h +++ b/llama/llama.cpp/src/llama-chat.h @@ -15,6 +15,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN, LLM_CHAT_TEMPLATE_MISTRAL_V7, LLM_CHAT_TEMPLATE_PHI_3, + LLM_CHAT_TEMPLATE_PHI_4, LLM_CHAT_TEMPLATE_FALCON_3, LLM_CHAT_TEMPLATE_ZEPHYR, LLM_CHAT_TEMPLATE_MONARCH, @@ -30,6 +31,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_LLAMA_3, LLM_CHAT_TEMPLATE_CHATGML_3, LLM_CHAT_TEMPLATE_CHATGML_4, + LLM_CHAT_TEMPLATE_GLMEDGE, LLM_CHAT_TEMPLATE_MINICPM, LLM_CHAT_TEMPLATE_EXAONE_3, LLM_CHAT_TEMPLATE_RWKV_WORLD, diff --git a/llama/llama.cpp/src/llama-context.cpp b/llama/llama.cpp/src/llama-context.cpp index 9d0e7ca3..7b22fe13 100644 --- a/llama/llama.cpp/src/llama-context.cpp +++ b/llama/llama.cpp/src/llama-context.cpp @@ -1,5 +1,8 @@ #include "llama-context.h" +#include "llama-impl.h" +#include "llama-mmap.h" + #include #include #include @@ -513,7 +516,7 @@ size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) { auto * buft = ggml_backend_cpu_buffer_type(); // try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory - auto * output_dev = lctx.model.dev_output.dev; + auto * output_dev = lctx.model.dev_output(); auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr; if (output_dev_host_buft) { buft = output_dev_host_buft; diff --git a/llama/llama.cpp/src/llama-context.h b/llama/llama.cpp/src/llama-context.h index 4980a60e..cf12c9d7 100644 --- a/llama/llama.cpp/src/llama-context.h +++ b/llama/llama.cpp/src/llama-context.h @@ -22,12 +22,12 @@ struct llama_context { const struct llama_model & model; - struct llama_cparams cparams; - struct llama_sbatch sbatch; // TODO: revisit if needed - struct llama_kv_cache kv_self; - struct llama_control_vector cvec; + struct llama_cparams cparams; + struct llama_sbatch sbatch; // TODO: revisit if needed + struct llama_kv_cache kv_self; + struct llama_adapter_cvec cvec; - std::unordered_map lora_adapters; + std::unordered_map lora; std::vector backends; std::vector> set_n_threads_fns; diff --git a/llama/llama.cpp/src/llama-grammar.cpp b/llama/llama.cpp/src/llama-grammar.cpp index 186dc9a2..98af1ba3 100644 --- a/llama/llama.cpp/src/llama-grammar.cpp +++ b/llama/llama.cpp/src/llama-grammar.cpp @@ -345,194 +345,194 @@ const char * llama_grammar_parser::parse_sequence( size_t last_sym_start = rule.size(); const char * pos = src; - auto handle_repetitions = [&](int min_times, int max_times) { + auto handle_repetitions = [&](int min_times, int max_times) { - if (last_sym_start == rule.size()) { - throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos); - } + if (last_sym_start == rule.size()) { + throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos); + } - // apply transformation to previous symbol (last_sym_start to end) according to - // the following rewrite rules: - // S{m,n} --> S S S (m times) S'(n-m) - // S'(x) ::= S S'(x-1) | - // (... n-m definitions of these S' rules ...) - // S'(1) ::= S | - // S{m,} --> S S S (m times) S' - // S' ::= S S' | - // S* --> S{0,} - // --> S' ::= S S' | - // S+ --> S{1,} - // --> S S' - // S' ::= S S' | - // S? --> S{0,1} - // --> S' - // S' ::= S | + // apply transformation to previous symbol (last_sym_start to end) according to + // the following rewrite rules: + // S{m,n} --> S S S (m times) S'(n-m) + // S'(x) ::= S S'(x-1) | + // (... n-m definitions of these S' rules ...) + // S'(1) ::= S | + // S{m,} --> S S S (m times) S' + // S' ::= S S' | + // S* --> S{0,} + // --> S' ::= S S' | + // S+ --> S{1,} + // --> S S' + // S' ::= S S' | + // S? --> S{0,1} + // --> S' + // S' ::= S | - llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end()); - if (min_times == 0) { - rule.resize(last_sym_start); - } else { - // Repeat the previous elements (min_times - 1) times - for (int i = 1; i < min_times; i++) { - rule.insert(rule.end(), prev_rule.begin(), prev_rule.end()); - } - } - - uint32_t last_rec_rule_id = 0; - auto n_opt = max_times < 0 ? 1 : max_times - min_times; - - llama_grammar_rule rec_rule(prev_rule); - for (int i = 0; i < n_opt; i++) { - rec_rule.resize(prev_rule.size()); - uint32_t rec_rule_id = generate_symbol_id( rule_name); - if (i > 0 || max_times < 0) { - rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id}); - } - rec_rule.push_back({LLAMA_GRETYPE_ALT, 0}); - rec_rule.push_back({LLAMA_GRETYPE_END, 0}); - add_rule( rec_rule_id, rec_rule); - last_rec_rule_id = rec_rule_id; - } - if (n_opt > 0) { - rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id}); - } - }; - - while (*pos) { - if (*pos == '"') { // literal string - pos++; - last_sym_start = rule.size(); - while (*pos != '"') { - if (!*pos) { - throw std::runtime_error("unexpected end of input"); - } - auto char_pair = parse_char(pos); - pos = char_pair.second; - rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first}); - } - pos = parse_space(pos + 1, is_nested); - } else if (*pos == '[') { // char range(s) - pos++; - enum llama_gretype start_type = LLAMA_GRETYPE_CHAR; - if (*pos == '^') { - pos++; - start_type = LLAMA_GRETYPE_CHAR_NOT; - } - last_sym_start = rule.size(); - while (*pos != ']') { - if (!*pos) { - throw std::runtime_error("unexpected end of input"); - } - auto char_pair = parse_char(pos); - pos = char_pair.second; - enum llama_gretype type = last_sym_start < rule.size() - ? LLAMA_GRETYPE_CHAR_ALT - : start_type; - - rule.push_back({type, char_pair.first}); - if (pos[0] == '-' && pos[1] != ']') { - if (!pos[1]) { - throw std::runtime_error("unexpected end of input"); - } - auto endchar_pair = parse_char(pos + 1); - pos = endchar_pair.second; - rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); - } - } - pos = parse_space(pos + 1, is_nested); - } else if (is_word_char(*pos)) { // rule reference - const char * name_end = parse_name(pos); - uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos); - pos = parse_space(name_end, is_nested); - last_sym_start = rule.size(); - rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); - } else if (*pos == '(') { // grouping - // parse nested alternates into synthesized rule - pos = parse_space(pos + 1, true); - uint32_t sub_rule_id = generate_symbol_id(rule_name); - pos = parse_alternates(pos, rule_name, sub_rule_id, true); - last_sym_start = rule.size(); - // output reference to synthesized rule - rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); - if (*pos != ')') { - throw std::runtime_error(std::string("expecting ')' at ") + pos); - } - pos = parse_space(pos + 1, is_nested); - } else if (*pos == '.') { // any char - last_sym_start = rule.size(); - rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0}); - pos = parse_space(pos + 1, is_nested); - } else if (*pos == '*') { - pos = parse_space(pos + 1, is_nested); - handle_repetitions(0, -1); - } else if (*pos == '+') { - pos = parse_space(pos + 1, is_nested); - handle_repetitions(1, -1); - } else if (*pos == '?') { - pos = parse_space(pos + 1, is_nested); - handle_repetitions(0, 1); - } else if (*pos == '{') { - pos = parse_space(pos + 1, is_nested); - - if (!is_digit_char(*pos)) { - throw std::runtime_error(std::string("expecting an int at ") + pos); - } - const char * int_end = parse_int(pos); - int min_times = std::stoul(std::string(pos, int_end - pos)); - pos = parse_space(int_end, is_nested); - - int max_times = -1; - - if (*pos == '}') { - max_times = min_times; - pos = parse_space(pos + 1, is_nested); - } else if (*pos == ',') { - pos = parse_space(pos + 1, is_nested); - - if (is_digit_char(*pos)) { - const char * int_end = parse_int(pos); - max_times = std::stoul(std::string(pos, int_end - pos)); - pos = parse_space(int_end, is_nested); - } - - if (*pos != '}') { - throw std::runtime_error(std::string("expecting '}' at ") + pos); - } - pos = parse_space(pos + 1, is_nested); - } else { - throw std::runtime_error(std::string("expecting ',' at ") + pos); - } - handle_repetitions(min_times, max_times); - } else { - break; + llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end()); + if (min_times == 0) { + rule.resize(last_sym_start); + } else { + // Repeat the previous elements (min_times - 1) times + for (int i = 1; i < min_times; i++) { + rule.insert(rule.end(), prev_rule.begin(), prev_rule.end()); } } - return pos; + + uint32_t last_rec_rule_id = 0; + auto n_opt = max_times < 0 ? 1 : max_times - min_times; + + llama_grammar_rule rec_rule(prev_rule); + for (int i = 0; i < n_opt; i++) { + rec_rule.resize(prev_rule.size()); + uint32_t rec_rule_id = generate_symbol_id( rule_name); + if (i > 0 || max_times < 0) { + rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id}); + } + rec_rule.push_back({LLAMA_GRETYPE_ALT, 0}); + rec_rule.push_back({LLAMA_GRETYPE_END, 0}); + add_rule( rec_rule_id, rec_rule); + last_rec_rule_id = rec_rule_id; + } + if (n_opt > 0) { + rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id}); + } + }; + + while (*pos) { + if (*pos == '"') { // literal string + pos++; + last_sym_start = rule.size(); + while (*pos != '"') { + if (!*pos) { + throw std::runtime_error("unexpected end of input"); + } + auto char_pair = parse_char(pos); + pos = char_pair.second; + rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first}); + } + pos = parse_space(pos + 1, is_nested); + } else if (*pos == '[') { // char range(s) + pos++; + enum llama_gretype start_type = LLAMA_GRETYPE_CHAR; + if (*pos == '^') { + pos++; + start_type = LLAMA_GRETYPE_CHAR_NOT; + } + last_sym_start = rule.size(); + while (*pos != ']') { + if (!*pos) { + throw std::runtime_error("unexpected end of input"); + } + auto char_pair = parse_char(pos); + pos = char_pair.second; + enum llama_gretype type = last_sym_start < rule.size() + ? LLAMA_GRETYPE_CHAR_ALT + : start_type; + + rule.push_back({type, char_pair.first}); + if (pos[0] == '-' && pos[1] != ']') { + if (!pos[1]) { + throw std::runtime_error("unexpected end of input"); + } + auto endchar_pair = parse_char(pos + 1); + pos = endchar_pair.second; + rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); + } + } + pos = parse_space(pos + 1, is_nested); + } else if (is_word_char(*pos)) { // rule reference + const char * name_end = parse_name(pos); + uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos); + pos = parse_space(name_end, is_nested); + last_sym_start = rule.size(); + rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); + } else if (*pos == '(') { // grouping + // parse nested alternates into synthesized rule + pos = parse_space(pos + 1, true); + uint32_t sub_rule_id = generate_symbol_id(rule_name); + pos = parse_alternates(pos, rule_name, sub_rule_id, true); + last_sym_start = rule.size(); + // output reference to synthesized rule + rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); + if (*pos != ')') { + throw std::runtime_error(std::string("expecting ')' at ") + pos); + } + pos = parse_space(pos + 1, is_nested); + } else if (*pos == '.') { // any char + last_sym_start = rule.size(); + rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0}); + pos = parse_space(pos + 1, is_nested); + } else if (*pos == '*') { + pos = parse_space(pos + 1, is_nested); + handle_repetitions(0, -1); + } else if (*pos == '+') { + pos = parse_space(pos + 1, is_nested); + handle_repetitions(1, -1); + } else if (*pos == '?') { + pos = parse_space(pos + 1, is_nested); + handle_repetitions(0, 1); + } else if (*pos == '{') { + pos = parse_space(pos + 1, is_nested); + + if (!is_digit_char(*pos)) { + throw std::runtime_error(std::string("expecting an int at ") + pos); + } + const char * int_end = parse_int(pos); + int min_times = std::stoul(std::string(pos, int_end - pos)); + pos = parse_space(int_end, is_nested); + + int max_times = -1; + + if (*pos == '}') { + max_times = min_times; + pos = parse_space(pos + 1, is_nested); + } else if (*pos == ',') { + pos = parse_space(pos + 1, is_nested); + + if (is_digit_char(*pos)) { + const char * int_end = parse_int(pos); + max_times = std::stoul(std::string(pos, int_end - pos)); + pos = parse_space(int_end, is_nested); + } + + if (*pos != '}') { + throw std::runtime_error(std::string("expecting '}' at ") + pos); + } + pos = parse_space(pos + 1, is_nested); + } else { + throw std::runtime_error(std::string("expecting ',' at ") + pos); + } + handle_repetitions(min_times, max_times); + } else { + break; + } } + return pos; +} const char * llama_grammar_parser::parse_rule(const char * src) { - const char * name_end = parse_name(src); - const char * pos = parse_space(name_end, false); - size_t name_len = name_end - src; - uint32_t rule_id = get_symbol_id(src, name_len); - const std::string name(src, name_len); + const char * name_end = parse_name(src); + const char * pos = parse_space(name_end, false); + size_t name_len = name_end - src; + uint32_t rule_id = get_symbol_id(src, name_len); + const std::string name(src, name_len); - if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) { - throw std::runtime_error(std::string("expecting ::= at ") + pos); - } - pos = parse_space(pos + 3, true); - - pos = parse_alternates(pos, name, rule_id, false); - - if (*pos == '\r') { - pos += pos[1] == '\n' ? 2 : 1; - } else if (*pos == '\n') { - pos++; - } else if (*pos) { - throw std::runtime_error(std::string("expecting newline or end at ") + pos); - } - return parse_space(pos, true); + if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) { + throw std::runtime_error(std::string("expecting ::= at ") + pos); } + pos = parse_space(pos + 3, true); + + pos = parse_alternates(pos, name, rule_id, false); + + if (*pos == '\r') { + pos += pos[1] == '\n' ? 2 : 1; + } else if (*pos == '\n') { + pos++; + } else if (*pos) { + throw std::runtime_error(std::string("expecting newline or end at ") + pos); + } + return parse_space(pos, true); +} bool llama_grammar_parser::parse(const char * src) { try { @@ -560,7 +560,7 @@ bool llama_grammar_parser::parse(const char * src) { } } } catch (const std::exception & err) { - fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what()); + fprintf(stderr, "%s: error parsing grammar: %s\n\n%s\n", __func__, err.what(), src); rules.clear(); return false; } @@ -960,10 +960,28 @@ struct llama_grammar * llama_grammar_init_impl( // Important: vec_rules has to be moved here, not copied, because stacks contains // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // then the pointers would be invalidated when the local vec_rules goes out of scope. - return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, }; + return new llama_grammar { + vocab, + std::move(vec_rules), + std::move(stacks), + /* .partial_utf8 = */ {}, + /* .lazy =*/ false, + /* .awaiting_trigger = */ false, + /* .trigger_buffer = */ "", + /* .trigger_tokens = */ {}, + /* .trigger_words = */ {}, + }; } -struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) { +struct llama_grammar * llama_grammar_init_impl( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + bool lazy, + const char ** trigger_words, + size_t num_trigger_words, + const llama_token * trigger_tokens, + size_t num_trigger_tokens) { llama_grammar_parser parser; // if there is a grammar, parse it @@ -1035,10 +1053,31 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, } } while (true); + std::vector vec_trigger_tokens; + std::vector vec_trigger_words; + for (size_t i = 0; i < num_trigger_tokens; i++) { + GGML_ASSERT(trigger_tokens != nullptr); + vec_trigger_tokens.push_back(trigger_tokens[i]); + } + for (size_t i = 0; i < num_trigger_words; i++) { + GGML_ASSERT(trigger_words != nullptr); + vec_trigger_words.push_back(trigger_words[i]); + } + // Important: vec_rules has to be moved here, not copied, because stacks contains // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // then the pointers would be invalidated when the local vec_rules goes out of scope. - return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, }; + return new llama_grammar { + vocab, + std::move(vec_rules), + std::move(stacks), + /* .partial_utf8 = */ {}, + /* .lazy = */ lazy, + /* .awaiting_trigger = */ lazy, + /* .trigger_buffer = */ "", + std::move(vec_trigger_tokens), + std::move(vec_trigger_words), + }; } void llama_grammar_free_impl(struct llama_grammar * grammar) { @@ -1055,6 +1094,11 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra grammar.rules, grammar.stacks, grammar.partial_utf8, + grammar.lazy, + grammar.awaiting_trigger, + grammar.trigger_buffer, + grammar.trigger_tokens, + grammar.trigger_words, }; // redirect elements in stacks to point to new rules @@ -1076,6 +1120,10 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) { GGML_ASSERT(grammar.vocab != nullptr); + if (grammar.awaiting_trigger) { + return; + } + bool allow_eog = false; for (const auto & stack : grammar.stacks) { if (stack.empty()) { @@ -1092,9 +1140,9 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_ for (size_t i = 0; i < cur_p->size; ++i) { const llama_token id = cur_p->data[i].id; - const std::string & piece = grammar.vocab->cache_token_to_piece.at(id); + const std::string & piece = grammar.vocab->token_to_piece(id); - if (llama_token_is_eog_impl(*grammar.vocab, id)) { + if (grammar.vocab->is_eog(id)) { if (!allow_eog) { cur_p->data[i].logit = -INFINITY; } @@ -1115,7 +1163,35 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) { GGML_ASSERT(grammar.vocab != nullptr); - if (llama_token_is_eog_impl(*grammar.vocab, token)) { + const auto & piece = grammar.vocab->token_to_piece(token); + + if (grammar.awaiting_trigger) { + if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) { + grammar.awaiting_trigger = false; + grammar.trigger_buffer.clear(); + llama_grammar_accept_str(grammar, piece); + LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str()); + return; + } else { + // TODO: consider a smarter incremental substring search algorithm (store last position to search from). + grammar.trigger_buffer += piece; + for (const auto & word : grammar.trigger_words) { + auto pos = grammar.trigger_buffer.find(word); + if (pos != std::string::npos) { + grammar.awaiting_trigger = false; + auto constrained_str = grammar.trigger_buffer.substr(pos); + grammar.trigger_buffer.clear(); + llama_grammar_accept_str(grammar, constrained_str); + LLAMA_LOG_DEBUG("Grammar triggered on word `%s`", word.c_str()); + return; + } + } + LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`)\n", token, piece.c_str()); + return; + } + } + + if (grammar.vocab->is_eog(token)) { for (const auto & stack : grammar.stacks) { if (stack.empty()) { return; @@ -1124,8 +1200,10 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token GGML_ABORT("fatal error"); } - const std::string & piece = grammar.vocab->cache_token_to_piece.at(token); + llama_grammar_accept_str(grammar, piece); +} +void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string & piece) { // Note terminating 0 in decoded string const auto decoded = decode_utf8(piece, grammar.partial_utf8); const auto & code_points = decoded.first; @@ -1135,5 +1213,7 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token } grammar.partial_utf8 = decoded.second; - GGML_ASSERT(!grammar.stacks.empty()); + if (grammar.stacks.empty()) { + throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece); + } } diff --git a/llama/llama.cpp/src/llama-grammar.h b/llama/llama.cpp/src/llama-grammar.h index f8b40c65..b143d834 100644 --- a/llama/llama.cpp/src/llama-grammar.h +++ b/llama/llama.cpp/src/llama-grammar.h @@ -114,6 +114,15 @@ struct llama_grammar { // buffer for partially generated UTF-8 sequence from accepted tokens llama_partial_utf8 partial_utf8; + + // lazy grammars wait for trigger words or tokens before constraining the sampling. + // we still have trigger_tokens for non-lazy grammars to force printing of special trigger tokens. + // (useful e.g. for tool_choice=required) + bool lazy = false; + bool awaiting_trigger = false; // Initialized to true for lazy grammars only + std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found. + std::vector trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special). + std::vector trigger_words; }; // @@ -127,7 +136,15 @@ struct llama_grammar * llama_grammar_init_impl( size_t n_rules, size_t start_rule_index); -struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root); +struct llama_grammar * llama_grammar_init_impl( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + bool lazy, + const char ** trigger_words, + size_t num_trigger_words, + const llama_token * trigger_tokens, + size_t num_trigger_tokens); void llama_grammar_free_impl(struct llama_grammar * grammar); @@ -141,3 +158,7 @@ void llama_grammar_apply_impl( void llama_grammar_accept_impl( struct llama_grammar & grammar, llama_token token); + +void llama_grammar_accept_str( + struct llama_grammar & grammar, + const std::string & piece); diff --git a/llama/llama.cpp/src/llama-hparams.cpp b/llama/llama.cpp/src/llama-hparams.cpp index 42f8a58f..0b841028 100644 --- a/llama/llama.cpp/src/llama-hparams.cpp +++ b/llama/llama.cpp/src/llama-hparams.cpp @@ -54,7 +54,7 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const { uint32_t llama_hparams::n_embd_k_s() const { if (wkv_head_size != 0) { // for RWKV models - return 2 * n_embd; + return token_shift_count * n_embd; } // TODO: maybe support other convolution strides than 1 @@ -82,4 +82,4 @@ bool llama_hparams::n_bskcn(uint32_t n, uint32_t il) const { bool llama_hparams::cross_attention_layers(uint32_t il) const { return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end(); -} +} \ No newline at end of file diff --git a/llama/llama.cpp/src/llama-hparams.h b/llama/llama.cpp/src/llama-hparams.h index f826cd9a..05383046 100644 --- a/llama/llama.cpp/src/llama-hparams.h +++ b/llama/llama.cpp/src/llama-hparams.h @@ -30,7 +30,6 @@ struct llama_hparams { bool use_par_res; bool swin_norm; - uint32_t n_vocab = 0; uint32_t n_ctx_train; // context size the model was trained on uint32_t n_embd; uint32_t n_embd_features = 0; @@ -41,8 +40,8 @@ struct llama_hparams { uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head uint32_t n_expert = 0; uint32_t n_expert_used = 0; - uint32_t n_vocab_type = 0; // for BERT-style token types uint32_t n_rel_attn_bkts = 0; + uint32_t n_vocab = 0; // for WavTokenizer struct llama_hparams_posnet posnet; @@ -79,6 +78,7 @@ struct llama_hparams { uint32_t time_mix_extra_dim = 0; uint32_t time_decay_extra_dim = 0; uint32_t wkv_head_size = 0; + uint32_t token_shift_count = 2; float rope_attn_factor = 1.0f; float rope_freq_base_train; @@ -141,7 +141,7 @@ struct llama_hparams { // Block skip connection bool n_bskcn(uint32_t n, uint32_t il) const; - // cross attention layers + // cross attention layers bool cross_attention_layers(uint32_t il) const; }; diff --git a/llama/llama.cpp/src/llama-impl.cpp b/llama/llama.cpp/src/llama-impl.cpp index a05ba4f6..6ec709dd 100644 --- a/llama/llama.cpp/src/llama-impl.cpp +++ b/llama/llama.cpp/src/llama-impl.cpp @@ -1,5 +1,6 @@ #include "llama-impl.h" +#include "gguf.h" #include "llama.h" #include @@ -138,7 +139,7 @@ std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) { { const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i); int arr_n = gguf_get_arr_n(ctx_gguf, i); - const void * data = gguf_get_arr_data(ctx_gguf, i); + const void * data = arr_type == GGUF_TYPE_STRING ? nullptr : gguf_get_arr_data(ctx_gguf, i); std::stringstream ss; ss << "["; for (int j = 0; j < arr_n; j++) { diff --git a/llama/llama.cpp/src/llama-impl.h b/llama/llama.cpp/src/llama-impl.h index 12d1fb08..02b1d07f 100644 --- a/llama/llama.cpp/src/llama-impl.h +++ b/llama/llama.cpp/src/llama-impl.h @@ -6,13 +6,13 @@ #include #ifdef __GNUC__ -#ifdef __MINGW32__ -#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +# if defined(__MINGW32__) && !defined(__clang__) +# define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +# else +# define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) +# endif #else -#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) -#endif -#else -#define LLAMA_ATTRIBUTE_FORMAT(...) +# define LLAMA_ATTRIBUTE_FORMAT(...) #endif // diff --git a/llama/llama.cpp/src/llama-kv-cache.cpp b/llama/llama.cpp/src/llama-kv-cache.cpp index cf814dbe..b541c5a3 100644 --- a/llama/llama.cpp/src/llama-kv-cache.cpp +++ b/llama/llama.cpp/src/llama-kv-cache.cpp @@ -72,39 +72,6 @@ bool llama_kv_cache_init( cache.v_l.reserve(n_layer); for (int i = 0; i < n_layer; i++) { - // for cross attention layers - if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layers(i)) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); - const llama_model::buft_list_t * buft_list; - if (offload) { - buft_list = model.dev_layer.at(i).buft_list; - } else { - buft_list = &model.cpu_buft_list; - } - ggml_backend_buffer_type_t buft = select_buft(*buft_list, - [&](ggml_context * ctx) { - ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); - if (hparams.rope_type == LLAMA_ROPE_TYPE_NONE) { - return k; - } - ggml_tensor * p = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1); - return ggml_rope(ctx, k, p, hparams.n_rot, hparams.rope_type); - }); - ggml_context * ctx = ctx_for_buft(buft); - - if (!ctx) { - LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__); - return false; - } - ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i)); - ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i)); - ggml_format_name(k, "cache_k_l%d", i); - ggml_format_name(v, "cache_v_l%d", i); - cache.k_l.push_back(k); - cache.v_l.push_back(v); - continue; - } - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); @@ -112,7 +79,7 @@ bool llama_kv_cache_init( ggml_backend_buffer_type_t buft; if (offload) { - auto * dev = model.dev_layer.at(i).dev; + auto * dev = model.dev_layer(i); buft = ggml_backend_dev_buffer_type(dev); } else { buft = ggml_backend_cpu_buffer_type(); @@ -124,8 +91,17 @@ bool llama_kv_cache_init( return false; } - ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); - ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + ggml_tensor * k, *v; + + // for cross attention layers + if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layers(i)) { + k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i)); + v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i)); + } else { + k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); + v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + } + ggml_format_name(k, "cache_k_l%d", i); ggml_format_name(v, "cache_v_l%d", i); cache.k_l.push_back(k); @@ -152,10 +128,10 @@ bool llama_kv_cache_init( struct llama_kv_cache_slot_info llama_kv_cache_find_slot( struct llama_kv_cache & cache, - const struct llama_ubatch & batch) { - const uint32_t n_tokens = batch.n_tokens; - const uint32_t n_seqs = batch.n_seqs; - const uint32_t n_seq_tokens = batch.n_seq_tokens; + const struct llama_ubatch & ubatch) { + const uint32_t n_tokens = ubatch.n_tokens; + const uint32_t n_seqs = ubatch.n_seqs; + const uint32_t n_seq_tokens = ubatch.n_seq_tokens; if (cache.recurrent) { // For recurrent state architectures (like Mamba or RWKV), @@ -163,16 +139,16 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot( // A slot should be always be contiguous. // can only process batches with an equal number of new tokens in each sequence - GGML_ASSERT(batch.equal_seqs); + GGML_ASSERT(ubatch.equal_seqs); int32_t min = cache.size - 1; int32_t max = 0; // everything should fit if all seq_ids are smaller than the max for (uint32_t s = 0; s < n_seqs; ++s) { - const uint32_t n_seq_id = batch.n_seq_id[s]; + const uint32_t n_seq_id = ubatch.n_seq_id[s]; for (uint32_t j = 0; j < n_seq_id; ++j) { - const llama_seq_id seq_id = batch.seq_id[s][j]; + const llama_seq_id seq_id = ubatch.seq_id[s][j]; if (seq_id < 0 || (uint32_t) seq_id >= cache.size) { // too big seq_id @@ -231,7 +207,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot( // find usable cell range for (uint32_t s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = batch.seq_id[s][0]; + const llama_seq_id seq_id = ubatch.seq_id[s][0]; llama_kv_cell & seq_meta = cache.cells[seq_id]; bool has_cell = false; if (seq_meta.tail >= 0) { @@ -270,7 +246,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot( // gather and re-order for (uint32_t s = 0; s < n_seqs; ++s) { int32_t dst_id = s + min; - int32_t src_id = cache.cells[batch.seq_id[s][0]].tail; + int32_t src_id = cache.cells[ubatch.seq_id[s][0]].tail; if (dst_id != src_id) { llama_kv_cell & dst_cell = cache.cells[dst_id]; llama_kv_cell & src_cell = cache.cells[src_id]; @@ -291,7 +267,7 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot( // update the pos of the used seqs for (uint32_t s = 0; s < n_seqs; ++s) { - const llama_pos last_pos = batch.pos[n_seq_tokens * s + n_seq_tokens - 1]; + const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1]; int32_t cell_id = s + min; llama_kv_cell & cell = cache.cells[cell_id]; @@ -299,12 +275,12 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot( // What should happen when the pos backtracks or skips a value? // Clearing the state mid-batch would require special-casing which isn't done. LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n", - __func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens); + __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens); } cell.pos = last_pos; cell.seq_id.clear(); - for (int32_t j = 0; j < batch.n_seq_id[s]; ++j) { - const llama_seq_id seq_id = batch.seq_id[s][j]; + for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) { + const llama_seq_id seq_id = ubatch.seq_id[s][j]; cell.seq_id.insert(seq_id); cache.cells[seq_id].tail = cell_id; } @@ -358,10 +334,10 @@ struct llama_kv_cache_slot_info llama_kv_cache_find_slot( for (uint32_t s = 0; s < n_seqs; s++) { for (uint32_t i = 0; i < n_seq_tokens; ++i) { uint32_t k = s*n_seq_tokens + i; - cache.cells[cache.head + k].pos = batch.pos[k]; + cache.cells[cache.head + k].pos = ubatch.pos[k]; - for (int32_t j = 0; j < batch.n_seq_id[s]; j++) { - cache.cells[cache.head + k].seq_id.insert(batch.seq_id[s][j]); + for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) { + cache.cells[cache.head + k].seq_id.insert(ubatch.seq_id[s][j]); } } } diff --git a/llama/llama.cpp/src/llama-kv-cache.h b/llama/llama.cpp/src/llama-kv-cache.h index dca6f399..1ed688e3 100644 --- a/llama/llama.cpp/src/llama-kv-cache.h +++ b/llama/llama.cpp/src/llama-kv-cache.h @@ -37,7 +37,7 @@ struct llama_kv_cache { bool can_shift = false; // Note: The value of head isn't only used to optimize searching - // for a free KV slot. llama_decode_internal also uses it, so it + // for a free KV slot. llama_decode_impl also uses it, so it // cannot be freely changed after a slot has been allocated. uint32_t head = 0; uint32_t size = 0; diff --git a/llama/llama.cpp/src/llama-mmap.cpp b/llama/llama.cpp/src/llama-mmap.cpp index a9932633..b716630a 100644 --- a/llama/llama.cpp/src/llama-mmap.cpp +++ b/llama/llama.cpp/src/llama-mmap.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #ifdef __has_include #if __has_include() @@ -35,7 +36,7 @@ // TODO: consider moving to llama-impl.h if needed in more places #if defined(_WIN32) -std::string llama_format_win_err(DWORD err) { +static std::string llama_format_win_err(DWORD err) { LPSTR buf; size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&buf, 0, NULL); @@ -241,12 +242,16 @@ llama_file::~llama_file() = default; size_t llama_file::tell() const { return pimpl->tell(); } size_t llama_file::size() const { return pimpl->size; } -int llama_file::fileno() const { +int llama_file::file_id() const { #ifdef _WIN32 return _fileno(pimpl->fp); +#else +#if defined(fileno) + return fileno(pimpl->fp); #else return ::fileno(pimpl->fp); #endif +#endif } void llama_file::seek(size_t offset, int whence) const { pimpl->seek(offset, whence); } @@ -265,7 +270,7 @@ struct llama_mmap::impl { impl(struct llama_file * file, size_t prefetch, bool numa) { size = file->size(); - int fd = file->fileno(); + int fd = file->file_id(); int flags = MAP_SHARED; if (numa) { prefetch = 0; } #ifdef __linux__ @@ -357,7 +362,7 @@ struct llama_mmap::impl { size = file->size(); - HANDLE hFile = (HANDLE) _get_osfhandle(file->fileno()); + HANDLE hFile = (HANDLE) _get_osfhandle(file->file_id()); HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL); diff --git a/llama/llama.cpp/src/llama-mmap.h b/llama/llama.cpp/src/llama-mmap.h index 6bcddee8..4e5aec3f 100644 --- a/llama/llama.cpp/src/llama-mmap.h +++ b/llama/llama.cpp/src/llama-mmap.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include @@ -18,7 +19,7 @@ struct llama_file { size_t tell() const; size_t size() const; - int fileno() const; + int file_id() const; // fileno overload void seek(size_t offset, int whence) const; diff --git a/llama/llama.cpp/src/llama-model-loader.cpp b/llama/llama.cpp/src/llama-model-loader.cpp index b12d6566..45d08721 100644 --- a/llama/llama.cpp/src/llama-model-loader.cpp +++ b/llama/llama.cpp/src/llama-model-loader.cpp @@ -7,6 +7,10 @@ #include #include +static const size_t kiB = 1024; +static const size_t MiB = 1024*kiB; +static const size_t GiB = 1024*MiB; + const char * llama_file_version_name(llama_fver version) { switch (version) { case GGUF_FILE_VERSION_V1: return "GGUF V1 (support until nov 2023)"; @@ -17,8 +21,78 @@ const char * llama_file_version_name(llama_fver version) { return "unknown"; } +static std::string llama_model_ftype_name(llama_ftype ftype) { + if (ftype & LLAMA_FTYPE_GUESSED) { + return llama_model_ftype_name((enum llama_ftype) (ftype & ~LLAMA_FTYPE_GUESSED)) + " (guessed)"; + } + + switch (ftype) { + case LLAMA_FTYPE_ALL_F32: return "all F32"; + case LLAMA_FTYPE_MOSTLY_F16: return "F16"; + case LLAMA_FTYPE_MOSTLY_BF16: return "BF16"; + case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0"; + case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1"; + case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0"; + case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1"; + case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0"; + case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q3_K_M: return "Q3_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q3_K_L: return "Q3_K - Large"; + case LLAMA_FTYPE_MOSTLY_Q4_K_S: return "Q4_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q4_K_M: return "Q4_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "Q5_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "Q5_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K"; + case LLAMA_FTYPE_MOSTLY_TQ1_0: return "TQ1_0 - 1.69 bpw ternary"; + case LLAMA_FTYPE_MOSTLY_TQ2_0: return "TQ2_0 - 2.06 bpw ternary"; + case LLAMA_FTYPE_MOSTLY_IQ2_XXS: return "IQ2_XXS - 2.0625 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ2_S: return "IQ2_S - 2.5 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ2_M: return "IQ2_M - 2.7 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_XS: return "IQ3_XS - 3.3 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_XXS: return "IQ3_XXS - 3.0625 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ1_S: return "IQ1_S - 1.5625 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ1_M: return "IQ1_M - 1.75 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw"; + + default: return "unknown, may not work"; + } +} + +// return a list of splits for a given path +// for example, given "-00002-of-00004.gguf", returns list of all 4 splits +static std::vector llama_get_list_splits(const std::string & path, const int idx, const int n_split) { + std::vector paths; + std::string split_prefix; + std::vector buf(llama_path_max(), 0); + + { + int ret = llama_split_prefix(buf.data(), buf.size(), path.c_str(), idx, n_split); + if (!ret) { + throw std::runtime_error(format("invalid split file name: %s", path.c_str())); + } + split_prefix = std::string(buf.data(), ret); + } + + if (split_prefix.empty()) { + throw std::runtime_error(format("invalid split file: %s", path.c_str())); + } + + for (int idx = 0; idx < n_split; ++idx) { + int ret = llama_split_path(buf.data(), buf.size(), split_prefix.c_str(), idx, n_split); + paths.push_back(std::string(buf.data(), ret)); + } + + return paths; +} + namespace GGUFMeta { - template + template struct GKV_Base_Type { static constexpr gguf_type gt = gt_; @@ -60,10 +134,11 @@ namespace GGUFMeta { public: static constexpr gguf_type gt = GGUF_TYPE_ARRAY; static ArrayInfo getter(const gguf_context *ctx, const int k) { + const enum gguf_type arr_type = gguf_get_arr_type(ctx, k); return ArrayInfo { - gguf_get_arr_type(ctx, k), + arr_type, size_t(gguf_get_arr_n(ctx, k)), - gguf_get_arr_data(ctx, k), + arr_type == GGUF_TYPE_STRING ? nullptr : gguf_get_arr_data(ctx, k), }; } }; @@ -368,7 +443,12 @@ namespace GGUFMeta { template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); template bool llama_model_loader::get_key_or_arr(const std::string & key, std::array & result, uint32_t n, bool required); -llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p) { +llama_model_loader::llama_model_loader( + const std::string & fname, + std::vector & splits, + bool use_mmap, + bool check_tensors, + const struct llama_model_kv_override * param_overrides_p) { int trace = 0; if (getenv("LLAMA_TRACE")) { trace = atoi(getenv("LLAMA_TRACE")); @@ -380,6 +460,7 @@ llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, } } + // Load the main GGUF struct ggml_context * ctx = NULL; struct gguf_init_params params = { /*.no_alloc = */ true, @@ -415,35 +496,54 @@ llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, // Load additional GGML contexts if (n_split > 1) { + // make sure the main file is loaded first uint16_t idx = 0; - get_key(llm_kv(LLM_KV_SPLIT_NO), idx); + const std::string kv_split_no = llm_kv(LLM_KV_SPLIT_NO); + get_key(kv_split_no, idx); if (idx != 0) { - throw std::runtime_error(format("illegal split file: %d, model must be loaded with the first split", idx)); + throw std::runtime_error(format("illegal split file idx: %d (file: %s), model must be loaded with the first split", idx, fname.c_str())); } - std::vector split_prefix(llama_path_max(), 0); - if (!llama_split_prefix(split_prefix.data(), split_prefix.size(), fname.c_str(), idx, n_split)) { - throw std::runtime_error(format("invalid split file: %s", fname.c_str())); + // generate list of splits if needed + if (splits.empty()) { + splits = llama_get_list_splits(fname, idx, n_split); + } + + // in case user give a custom list of splits, check if it matches the expected number + if (n_split != (uint16_t)splits.size()) { + throw std::runtime_error(format("invalid split count, given: %zu splits, but expected %d", splits.size(), n_split)); } if (trace > 0) { LLAMA_LOG_INFO("%s: loading additional %d GGUFs\n", __func__, n_split); } - std::vector split_path(llama_path_max(), 0); + // load other splits for (idx = 1; idx < n_split; idx++) { - llama_split_path(split_path.data(), split_path.size(), split_prefix.data(), idx, n_split); + const char * fname_split = splits[idx].c_str(); struct gguf_init_params split_params = { /*.no_alloc = */ true, /*.ctx = */ &ctx, }; - gguf_context_ptr ctx_gguf { gguf_init_from_file(split_path.data(), split_params) }; + gguf_context_ptr ctx_gguf { gguf_init_from_file(fname_split, split_params) }; if (!ctx_gguf) { - throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, split_path.data())); + throw std::runtime_error(format("%s: failed to load GGUF split from %s\n", __func__, fname_split)); } - files.emplace_back(new llama_file(split_path.data(), "rb")); + // check idx + { + const int kid = gguf_find_key(ctx_gguf.get(), kv_split_no.c_str()); + if (kid < 0) { + throw std::runtime_error(format("missing key %s in GGUF split %s", kv_split_no.c_str(), fname_split)); + } + int idx_gguf = gguf_get_val_u16(ctx_gguf.get(), kid); + if (idx_gguf != idx) { + throw std::runtime_error(format("invalid split file idx: %d (file: %s), expected %d", idx_gguf, fname_split, idx)); + } + } + + files.emplace_back(new llama_file(fname_split, "rb")); contexts.emplace_back(ctx); // Save tensors data offset info of the shard. @@ -556,7 +656,7 @@ llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, const enum gguf_type type = gguf_get_kv_type(meta.get(), i); const std::string type_name = type == GGUF_TYPE_ARRAY - ? format("%s[%s,%d]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(meta.get(), i)), gguf_get_arr_n(meta.get(), i)) + ? format("%s[%s,%zu]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(meta.get(), i)), gguf_get_arr_n(meta.get(), i)) : gguf_type_name(type); std::string value = gguf_kv_to_str(meta.get(), i); @@ -722,7 +822,7 @@ void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps for (const auto & file : files) { auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU)); auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa"); - std::unique_ptr mapping(new llama_mmap(file.get(), prefetch ? -1 : 0, is_numa_fn())); + std::unique_ptr mapping = std::make_unique(file.get(), prefetch ? -1 : 0, is_numa_fn()); mmaps_used.emplace_back(mapping->size(), 0); if (mlock_mmaps) { std::unique_ptr mlock_mmap(new llama_mlock()); @@ -1011,3 +1111,17 @@ bool llama_model_loader::load_all_data( return true; } + +std::string llama_model_loader::ftype_name() const { + return llama_model_ftype_name(ftype); +} + +void llama_model_loader::print_info() const { + LLAMA_LOG_INFO("%s: file format = %s\n", __func__, llama_file_version_name(fver)); + LLAMA_LOG_INFO("%s: file type = %s\n", __func__, llama_model_ftype_name(ftype).c_str()); + if (n_bytes < GiB) { + LLAMA_LOG_INFO("%s: file size = %.2f MiB (%.2f BPW) \n", __func__, n_bytes/1024.0/1024.0, n_bytes*8.0/n_elements); + } else { + LLAMA_LOG_INFO("%s: file size = %.2f GiB (%.2f BPW) \n", __func__, n_bytes/1024.0/1024.0/1024.0, n_bytes*8.0/n_elements); + } +} diff --git a/llama/llama.cpp/src/llama-model-loader.h b/llama/llama.cpp/src/llama-model-loader.h index 1ec47819..fe35404b 100644 --- a/llama/llama.cpp/src/llama-model-loader.h +++ b/llama/llama.cpp/src/llama-model-loader.h @@ -90,7 +90,12 @@ struct llama_model_loader { size_t size_data = 0; std::vector> mmaps_used; - llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p); + llama_model_loader( + const std::string & fname, + std::vector & splits, // optional, only need if the split does not follow naming scheme + bool use_mmap, + bool check_tensors, + const struct llama_model_kv_override * param_overrides_p); template typename std::enable_if::value, bool>::type @@ -155,4 +160,8 @@ struct llama_model_loader { llama_mlocks * lmlocks, llama_progress_callback progress_callback, void * progress_callback_user_data); + + std::string ftype_name() const; + + void print_info() const; }; diff --git a/llama/llama.cpp/src/llama-model.cpp b/llama/llama.cpp/src/llama-model.cpp index 4f9bbf90..21819080 100644 --- a/llama/llama.cpp/src/llama-model.cpp +++ b/llama/llama.cpp/src/llama-model.cpp @@ -1,128 +1,85 @@ #include "llama-model.h" #include "llama-impl.h" +#include "llama-mmap.h" #include "llama-model-loader.h" -#include "unicode.h" // TODO: remove +#include "ggml-cpp.h" #include #include +#include #include +#include #include #include -static const size_t kiB = 1024; -static const size_t MiB = 1024*kiB; -static const size_t GiB = 1024*MiB; - const char * llm_type_name(llm_type type) { switch (type) { - case MODEL_14M: return "14M"; - case MODEL_17M: return "17M"; - case MODEL_22M: return "22M"; - case MODEL_33M: return "33M"; - case MODEL_60M: return "60M"; - case MODEL_70M: return "70M"; - case MODEL_80M: return "80M"; - case MODEL_109M: return "109M"; - case MODEL_137M: return "137M"; - case MODEL_160M: return "160M"; - case MODEL_220M: return "220M"; - case MODEL_250M: return "250M"; - case MODEL_270M: return "270M"; - case MODEL_335M: return "335M"; - case MODEL_410M: return "410M"; - case MODEL_450M: return "450M"; - case MODEL_770M: return "770M"; - case MODEL_780M: return "780M"; - case MODEL_0_5B: return "0.5B"; - case MODEL_1B: return "1B"; - case MODEL_1_3B: return "1.3B"; - case MODEL_1_4B: return "1.4B"; - case MODEL_1_5B: return "1.5B"; - case MODEL_1_6B: return "1.6B"; - case MODEL_2B: return "2B"; - case MODEL_2_8B: return "2.8B"; - case MODEL_3B: return "3B"; - case MODEL_4B: return "4B"; - case MODEL_6B: return "6B"; - case MODEL_6_9B: return "6.9B"; - case MODEL_7B: return "7B"; - case MODEL_8B: return "8B"; - case MODEL_9B: return "9B"; - case MODEL_11B: return "11B"; - case MODEL_12B: return "12B"; - case MODEL_13B: return "13B"; - case MODEL_14B: return "14B"; - case MODEL_15B: return "15B"; - case MODEL_16B: return "16B"; - case MODEL_20B: return "20B"; - case MODEL_30B: return "30B"; - case MODEL_32B: return "32B"; - case MODEL_34B: return "34B"; - case MODEL_35B: return "35B"; - case MODEL_40B: return "40B"; - case MODEL_65B: return "65B"; - case MODEL_70B: return "70B"; - case MODEL_236B: return "236B"; - case MODEL_314B: return "314B"; - case MODEL_671B: return "671B"; - case MODEL_SMALL: return "0.1B"; - case MODEL_MEDIUM: return "0.4B"; - case MODEL_LARGE: return "0.8B"; - case MODEL_XL: return "1.5B"; - case MODEL_A1_7B: return "A1.7B"; - case MODEL_A2_7B: return "A2.7B"; - case MODEL_8x7B: return "8x7B"; - case MODEL_8x22B: return "8x22B"; - case MODEL_16x12B: return "16x12B"; - case MODEL_10B_128x3_66B: return "10B+128x3.66B"; - case MODEL_57B_A14B: return "57B.A14B"; - case MODEL_27B: return "27B"; - default: return "?B"; - } -} - -static std::string llama_model_ftype_name(llama_ftype ftype) { - if (ftype & LLAMA_FTYPE_GUESSED) { - return llama_model_ftype_name((enum llama_ftype) (ftype & ~LLAMA_FTYPE_GUESSED)) + " (guessed)"; - } - - switch (ftype) { - case LLAMA_FTYPE_ALL_F32: return "all F32"; - case LLAMA_FTYPE_MOSTLY_F16: return "F16"; - case LLAMA_FTYPE_MOSTLY_BF16: return "BF16"; - case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0"; - case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1"; - case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0"; - case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1"; - case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0"; - case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium"; - case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small"; - case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small"; - case LLAMA_FTYPE_MOSTLY_Q3_K_M: return "Q3_K - Medium"; - case LLAMA_FTYPE_MOSTLY_Q3_K_L: return "Q3_K - Large"; - case LLAMA_FTYPE_MOSTLY_Q4_K_S: return "Q4_K - Small"; - case LLAMA_FTYPE_MOSTLY_Q4_K_M: return "Q4_K - Medium"; - case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "Q5_K - Small"; - case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "Q5_K - Medium"; - case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K"; - case LLAMA_FTYPE_MOSTLY_TQ1_0: return "TQ1_0 - 1.69 bpw ternary"; - case LLAMA_FTYPE_MOSTLY_TQ2_0: return "TQ2_0 - 2.06 bpw ternary"; - case LLAMA_FTYPE_MOSTLY_IQ2_XXS: return "IQ2_XXS - 2.0625 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ2_S: return "IQ2_S - 2.5 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ2_M: return "IQ2_M - 2.7 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ3_XS: return "IQ3_XS - 3.3 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ3_XXS: return "IQ3_XXS - 3.0625 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ1_S: return "IQ1_S - 1.5625 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ1_M: return "IQ1_M - 1.75 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw"; - - default: return "unknown, may not work"; + case LLM_TYPE_14M: return "14M"; + case LLM_TYPE_17M: return "17M"; + case LLM_TYPE_22M: return "22M"; + case LLM_TYPE_33M: return "33M"; + case LLM_TYPE_60M: return "60M"; + case LLM_TYPE_70M: return "70M"; + case LLM_TYPE_80M: return "80M"; + case LLM_TYPE_109M: return "109M"; + case LLM_TYPE_137M: return "137M"; + case LLM_TYPE_160M: return "160M"; + case LLM_TYPE_220M: return "220M"; + case LLM_TYPE_250M: return "250M"; + case LLM_TYPE_270M: return "270M"; + case LLM_TYPE_335M: return "335M"; + case LLM_TYPE_410M: return "410M"; + case LLM_TYPE_450M: return "450M"; + case LLM_TYPE_770M: return "770M"; + case LLM_TYPE_780M: return "780M"; + case LLM_TYPE_0_5B: return "0.5B"; + case LLM_TYPE_1B: return "1B"; + case LLM_TYPE_1_3B: return "1.3B"; + case LLM_TYPE_1_4B: return "1.4B"; + case LLM_TYPE_1_5B: return "1.5B"; + case LLM_TYPE_1_6B: return "1.6B"; + case LLM_TYPE_2B: return "2B"; + case LLM_TYPE_2_8B: return "2.8B"; + case LLM_TYPE_3B: return "3B"; + case LLM_TYPE_4B: return "4B"; + case LLM_TYPE_6B: return "6B"; + case LLM_TYPE_6_9B: return "6.9B"; + case LLM_TYPE_7B: return "7B"; + case LLM_TYPE_8B: return "8B"; + case LLM_TYPE_9B: return "9B"; + case LLM_TYPE_11B: return "11B"; + case LLM_TYPE_12B: return "12B"; + case LLM_TYPE_13B: return "13B"; + case LLM_TYPE_14B: return "14B"; + case LLM_TYPE_15B: return "15B"; + case LLM_TYPE_16B: return "16B"; + case LLM_TYPE_20B: return "20B"; + case LLM_TYPE_30B: return "30B"; + case LLM_TYPE_32B: return "32B"; + case LLM_TYPE_34B: return "34B"; + case LLM_TYPE_35B: return "35B"; + case LLM_TYPE_40B: return "40B"; + case LLM_TYPE_65B: return "65B"; + case LLM_TYPE_70B: return "70B"; + case LLM_TYPE_236B: return "236B"; + case LLM_TYPE_314B: return "314B"; + case LLM_TYPE_671B: return "671B"; + case LLM_TYPE_SMALL: return "0.1B"; + case LLM_TYPE_MEDIUM: return "0.4B"; + case LLM_TYPE_LARGE: return "0.8B"; + case LLM_TYPE_XL: return "1.5B"; + case LLM_TYPE_A1_7B: return "A1.7B"; + case LLM_TYPE_A2_7B: return "A2.7B"; + case LLM_TYPE_8x7B: return "8x7B"; + case LLM_TYPE_8x22B: return "8x22B"; + case LLM_TYPE_16x12B: return "16x12B"; + case LLM_TYPE_16x3_8B: return "16x3.8B"; + case LLM_TYPE_10B_128x3_66B: return "10B+128x3.66B"; + case LLM_TYPE_57B_A14B: return "57B.A14B"; + case LLM_TYPE_27B: return "27B"; + default: return "?B"; } } @@ -134,44 +91,6 @@ static const char * llama_expert_gating_func_name(llama_expert_gating_func_type } } -std::string llama_model_arch_name (const llama_model & model) { - return llm_arch_name(model.arch); -} - -std::string llama_model_type_name (const llama_model & model) { - return llm_type_name(model.type); -} - -std::string llama_model_ftype_name(const llama_model & model) { - return llama_model_ftype_name(model.ftype); -} - -ggml_backend_buffer_type_t llama_model_select_buft(const llama_model & model, int il) { - return select_buft( - *model.dev_layer.at(il).buft_list, - [&](ggml_context * ctx) { - ggml_tensor * cur = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, model.hparams.n_embd); - ggml_tensor * layer_dir = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, model.hparams.n_embd); - return ggml_add(ctx, cur, layer_dir); - }); -} - -struct ggml_tensor * llama_model_get_tensor(const struct llama_model & model, const char * name) { - auto it = std::find_if(model.tensors_by_name.begin(), model.tensors_by_name.end(), - [name](const std::pair & it) { - return it.first == name; - }); - if (it == model.tensors_by_name.end()) { - return nullptr; - } - - return it->second; -} - -size_t llama_model_max_nodes(const llama_model & model) { - return std::max(8192, model.tensors_by_name.size()*5); -} - static const std::map LLAMA_ROPE_SCALING_TYPES = { { LLAMA_ROPE_SCALING_TYPE_NONE, "none" }, { LLAMA_ROPE_SCALING_TYPE_LINEAR, "linear" }, @@ -189,37 +108,284 @@ static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::st return LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; } -// NOTE: avoid ever using this except for building the token_to_piece caches -static std::string llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) { - std::string piece; - piece.resize(piece.capacity()); // using string internal cache - const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special); - if (n_chars < 0) { - piece.resize(-n_chars); - int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special); - GGML_ASSERT(check == -n_chars); - } - else { - piece.resize(n_chars); +// checks if the weight tensor can be used with the specified buffer type and device +static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) { + GGML_ASSERT(w != nullptr); + + if (op == GGML_OP_NONE) { + return true; } - return piece; + ggml_init_params params = { + /*.mem_size =*/ ggml_tensor_overhead()*8, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ggml_context_ptr ctx_ptr { ggml_init(params) }; + if (!ctx_ptr) { + throw std::runtime_error(format("failed to create ggml context")); + } + ggml_context * ctx = ctx_ptr.get(); + + ggml_tensor * op_tensor = nullptr; + + switch (op) { + case GGML_OP_GET_ROWS: + { + ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512); + op_tensor = ggml_get_rows(ctx, w, b); + } break; + case GGML_OP_MUL_MAT: + { + ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], 512, w->ne[2], w->ne[3]); + op_tensor = ggml_mul_mat(ctx, w, b); + } break; + case GGML_OP_MUL_MAT_ID: + { + int n_expert_used = hparams.n_expert_used; + ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0], n_expert_used, 512); + ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_expert_used, 512); + op_tensor = ggml_mul_mat_id(ctx, w, b, ids); + } break; + case GGML_OP_ADD: + { + ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]); + op_tensor = ggml_add(ctx, a, w); + } break; + case GGML_OP_MUL: + { + ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]); + op_tensor = ggml_mul(ctx, a, w); + } break; + case GGML_OP_DIV: + { + ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, w->ne[0]); + op_tensor = ggml_div(ctx, a, w); + } break; + case GGML_OP_ROPE: + { + int n_embd_head = hparams.n_embd_head_v; + int n_head = hparams.n_head(); + ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, 512); + ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512); + op_tensor = ggml_rope_ext( + ctx, a, b, w, + 0, 0, 0, 0, 0, + 0, 0, 0, 0 + ); + + } break; + case GGML_OP_SSM_CONV: + { + // FIXME + ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 12345, w->ne[1], 6789); + op_tensor = ggml_ssm_conv(ctx, conv_x, w); + } break; + case GGML_OP_SSM_SCAN: + { + // FIXME + const int64_t d_state = w->ne[0]; + const int64_t d_inner = w->ne[1]; + const int64_t n_seq_tokens = 512; + const int64_t n_seqs = 1; + ggml_tensor * s = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, d_inner, n_seqs); + ggml_tensor * x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs); + ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs); + ggml_tensor * B = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs); + ggml_tensor * C = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs); + op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C); + } break; + case GGML_OP_RWKV_WKV6: + { + // FIXME + const int64_t S = 123; + const int64_t H = 123; + const int64_t n_tokens = 123; + const int64_t n_seqs = 123; + ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); + ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); + ggml_tensor * r = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); + ggml_tensor * tf = w; + ggml_tensor * td = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); + ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_seqs, S, H); + op_tensor = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, state); + } break; + case GGML_OP_IM2COL: + { + const int n_embd = hparams.n_embd; + ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd, w->ne[1], 1, 1); + op_tensor = ggml_im2col(ctx, w, b, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F16); + } break; + default: + GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name); + } + + // create a temporary dummy buffer for the weight so that supports_op can check the buffer type + GGML_ASSERT(w->buffer == nullptr); + w->buffer = ggml_backend_buft_alloc_buffer(buft, 0); + bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor); + ggml_backend_buffer_free(w->buffer); + w->buffer = nullptr; + + return op_supported; } -void llm_load_stats(llama_model_loader & ml, llama_model & model) { - model.n_elements = ml.n_elements; - model.n_bytes = ml.n_bytes; +// lists of buffer types used for each layer +using buft_list_t = std::vector>; + +// find the first buffer type in the list that can use the tensor +static ggml_backend_buffer_type_t select_weight_buft(const llama_hparams & hparams, ggml_tensor * tensor, ggml_op op, const buft_list_t & buft_list) { + GGML_ASSERT(!buft_list.empty()); + for (const auto & cur : buft_list) { + ggml_backend_dev_t cur_dev = cur.first; + ggml_backend_buffer_type_t cur_buft = cur.second; + if (weight_buft_supported(hparams, tensor, op, cur_buft, cur_dev)) { + return cur_buft; + } + } + return nullptr; } -void llm_load_arch(llama_model_loader & ml, llama_model & model) { - model.arch = ml.get_arch(); - if (model.arch == LLM_ARCH_UNKNOWN) { +// CPU: ACCEL -> CPU extra -> GPU host -> CPU +static buft_list_t make_cpu_buft_list(const std::vector & devices) { + buft_list_t buft_list; + + // add ACCEL buffer types + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) { + auto * buft = ggml_backend_dev_buffer_type(dev); + // skip + if (buft != ggml_backend_cpu_buffer_type()) { + buft_list.emplace_back(dev, buft); + } + } + } + + // add extra buffer types + auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); + auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) + ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts"); + if (ggml_backend_dev_get_extra_bufts_fn) { + ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev); + while (extra_bufts && *extra_bufts) { + buft_list.emplace_back(cpu_dev, *extra_bufts); + ++extra_bufts; + } + } + + // add a host buffer type + // storing the tensors in a host buffer is useful when the processing of large batches + // is offloaded to a GPU device, since it reduces the time spent on data transfers + // generally, this will be done using the first device in the list + // a better approach would be to handle this on a weight-by-weight basis using the offload_op + // function of the device to determine if it would benefit from being stored in a host buffer + for (auto * dev : devices) { + ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev); + if (buft) { + buft_list.emplace_back(dev, buft); + break; + } + } + + // add the CPU buffer type + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU) { + buft_list.emplace_back(dev, ggml_backend_dev_buffer_type(dev)); + } + } + + return buft_list; +} + +// GPU: split if LLAMA_SPLIT_MODE_ROW -> GPU +static buft_list_t make_gpu_buft_list(ggml_backend_dev_t dev, enum llama_split_mode split_mode, const float * tensor_split) { + buft_list_t buft_list; + + // add the device split buffer type if requested and available + if (split_mode == LLAMA_SPLIT_MODE_ROW) { + ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev); + auto ggml_backend_split_buffer_type_fn = (ggml_backend_split_buffer_type_t) + ggml_backend_reg_get_proc_address(reg, "ggml_backend_split_buffer_type"); + if (ggml_backend_split_buffer_type_fn) { + size_t dev_index = [&]() { + auto * reg = ggml_backend_dev_backend_reg(dev); + for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); ++i) { + if (ggml_backend_reg_dev_get(reg, i) == dev) { + return i; + } + } + throw std::runtime_error(format("device %s not found in its backend reg", ggml_backend_dev_name(dev))); + }(); + auto * buft = ggml_backend_split_buffer_type_fn(dev_index, tensor_split); + if (buft != nullptr) { + buft_list.emplace_back(dev, buft); + } + } + } + + // add the device default buffer type + buft_list.emplace_back(dev, ggml_backend_dev_buffer_type(dev)); + + return buft_list; +} + +struct llama_model::impl { + impl() {} + ~impl() {} + + uint64_t n_elements = 0; + + size_t n_bytes = 0; + + std::string desc_str; + + // model memory mapped files + llama_mmaps mappings; + + // objects representing data potentially being locked in memory + llama_mlocks mlock_bufs; + llama_mlocks mlock_mmaps; + + // contexts where the model tensors metadata is stored + std::vector ctxs; + + // the model memory buffers for the tensor data + std::vector bufs; + + buft_list_t cpu_buft_list; + std::map gpu_buft_list; + + struct layer_dev { + ggml_backend_dev_t dev; + buft_list_t * buft_list; + }; + + layer_dev dev_input = {}; + layer_dev dev_output = {}; + std::vector dev_layer; +}; + +llama_model::llama_model(const struct llama_model_params & params) : params(params), pimpl(std::make_unique()) { +} + +llama_model::~llama_model() {} + +void llama_model::load_stats(llama_model_loader & ml) { + pimpl->n_elements = ml.n_elements; + pimpl->n_bytes = ml.n_bytes; +} + +void llama_model::load_arch(llama_model_loader & ml) { + arch = ml.get_arch(); + if (arch == LLM_ARCH_UNKNOWN) { throw std::runtime_error("unknown model architecture: '" + ml.get_arch_name() + "'"); } } -void llm_load_hparams(llama_model_loader & ml, llama_model & model) { - auto & hparams = model.hparams; +void llama_model::load_hparams(llama_model_loader & ml) { const gguf_context * ctx = ml.meta.get(); // get metadata as string @@ -230,13 +396,11 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { } const char * name = gguf_get_key(ctx, i); const std::string value = gguf_kv_to_str(ctx, i); - model.gguf_kv.emplace(name, value); + gguf_kv.emplace(name, value); } // get general kv - ml.get_key(LLM_KV_GENERAL_NAME, model.name, false); - - // get hparams kv + ml.get_key(LLM_KV_GENERAL_NAME, name, false); ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab, false); // everything past this point is not vocab-related @@ -249,8 +413,9 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer); ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false); ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false); + ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false); - if (model.arch == LLM_ARCH_WAVTOKENIZER_DEC) { + if (arch == LLM_ARCH_WAVTOKENIZER_DEC) { ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd_features); ml.get_key(LLM_KV_POSNET_EMBEDDING_LENGTH, hparams.posnet.n_embd); @@ -274,8 +439,8 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0); std::fill(hparams.cross_attn_layers.begin(), hparams.cross_attn_layers.end(), -1); - ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false); - ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false); + ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false); + ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false); ml.get_arr(LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, hparams.cross_attn_layers, false); // n_head_kv is optional, default to n_head @@ -325,7 +490,7 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false); - if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_MLLAMA || model.arch == LLM_ARCH_DECI || model.arch == LLM_ARCH_FALCON) { + if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_MLLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON) { if (hparams.n_rot != hparams.n_embd_head_k) { throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k)); } @@ -336,34 +501,36 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { hparams.n_embd_head_v = 0; } - using e_model = llm_type; // TMP + // for differentiating model types + uint32_t n_vocab = 0; + ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false); // arch-specific KVs - switch (model.arch) { + switch (arch) { case LLM_ARCH_LLAMA: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); if (hparams.n_expert == 8) { switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_8x7B; break; - case 56: model.type = e_model::MODEL_8x22B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_8x7B; break; + case 56: type = LLM_TYPE_8x22B; break; + default: type = LLM_TYPE_UNKNOWN; } } else { switch (hparams.n_layer) { - case 16: model.type = e_model::MODEL_1B; break; // Llama 3.2 1B - case 22: model.type = e_model::MODEL_1B; break; - case 26: model.type = e_model::MODEL_3B; break; - case 28: model.type = e_model::MODEL_3B; break; // Llama 3.2 3B + case 16: type = LLM_TYPE_1B; break; // Llama 3.2 1B + case 22: type = LLM_TYPE_1B; break; + case 26: type = LLM_TYPE_3B; break; + case 28: type = LLM_TYPE_3B; break; // Llama 3.2 3B // granite uses a vocab with len 49152 - case 32: model.type = hparams.n_vocab == 49152 ? e_model::MODEL_3B : (hparams.n_vocab < 40000 ? e_model::MODEL_7B : e_model::MODEL_8B); break; - case 36: model.type = e_model::MODEL_8B; break; // granite - case 40: model.type = e_model::MODEL_13B; break; - case 48: model.type = e_model::MODEL_34B; break; - case 60: model.type = e_model::MODEL_30B; break; - case 80: model.type = hparams.n_head() == hparams.n_head_kv() ? e_model::MODEL_65B : e_model::MODEL_70B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = n_vocab == 49152 ? LLM_TYPE_3B : (n_vocab < 40000 ? LLM_TYPE_7B : LLM_TYPE_8B); break; + case 36: type = LLM_TYPE_8B; break; // granite + case 40: type = LLM_TYPE_13B; break; + case 48: type = LLM_TYPE_34B; break; + case 60: type = LLM_TYPE_30B; break; + case 80: type = hparams.n_head() == hparams.n_head_kv() ? LLM_TYPE_65B : LLM_TYPE_70B; break; + default: type = LLM_TYPE_UNKNOWN; } } } break; @@ -372,42 +539,42 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 40: model.type = e_model::MODEL_11B; break; - case 100: model.type = e_model::MODEL_90B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 40: type = LLM_TYPE_11B; break; + case 100: type = LLM_TYPE_90B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_DECI: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_7B; break; - case 80: model.type = e_model::MODEL_70B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_7B; break; + case 80: type = LLM_TYPE_70B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_MINICPM: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); - ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); - ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); + ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); switch (hparams.n_layer) { - case 52: model.type = e_model::MODEL_1B; break; - case 40: model.type = e_model::MODEL_2B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 52: type = LLM_TYPE_1B; break; + case 40: type = LLM_TYPE_2B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_MINICPM3: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); - ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); switch (hparams.n_layer) { - case 62: model.type = e_model::MODEL_4B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 62: type = LLM_TYPE_4B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_GROK: @@ -415,8 +582,8 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 64: model.type = e_model::MODEL_314B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 64: type = LLM_TYPE_314B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_FALCON: @@ -424,21 +591,21 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_7B; break; - case 60: model.type = e_model::MODEL_40B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_7B; break; + case 60: type = LLM_TYPE_40B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_BAICHUAN: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_7B; break; - case 40: model.type = e_model::MODEL_13B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_13B; break; + default: type = LLM_TYPE_UNKNOWN; } - if (model.type == e_model::MODEL_13B) { + if (type == LLM_TYPE_13B) { // TODO: become GGUF KV parameter hparams.f_max_alibi_bias = 8.0f; } @@ -447,19 +614,19 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 24: model.type = e_model::MODEL_1B; break; - case 36: model.type = e_model::MODEL_3B; break; - case 42: model.type = e_model::MODEL_7B; break; - case 40: model.type = e_model::MODEL_15B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 24: type = LLM_TYPE_1B; break; + case 36: type = LLM_TYPE_3B; break; + case 42: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_15B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_REFACT: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_1B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_1B; break; + default: type = LLM_TYPE_UNKNOWN; } // TODO: become GGUF KV parameter @@ -469,48 +636,45 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); - ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type); ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); switch (hparams.n_layer) { case 3: - model.type = e_model::MODEL_17M; break; // bge-micro + type = LLM_TYPE_17M; break; // bge-micro case 6: - model.type = e_model::MODEL_22M; break; // MiniLM-L6 + type = LLM_TYPE_22M; break; // MiniLM-L6 case 12: switch (hparams.n_embd) { - case 384: model.type = e_model::MODEL_33M; break; // MiniLM-L12, bge-small - case 768: model.type = e_model::MODEL_109M; break; // bge-base - default: model.type = e_model::MODEL_UNKNOWN; + case 384: type = LLM_TYPE_33M; break; // MiniLM-L12, bge-small + case 768: type = LLM_TYPE_109M; break; // bge-base + default: type = LLM_TYPE_UNKNOWN; } break; case 24: - model.type = e_model::MODEL_335M; break; // bge-large - default: model.type = e_model::MODEL_UNKNOWN; + type = LLM_TYPE_335M; break; // bge-large + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_JINA_BERT_V2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); - ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type); ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); hparams.f_max_alibi_bias = 8.0f; switch (hparams.n_layer) { - case 4: model.type = e_model::MODEL_33M; break; // jina-embeddings-small - case 12: model.type = e_model::MODEL_137M; break; // jina-embeddings-base - default: model.type = e_model::MODEL_UNKNOWN; + case 4: type = LLM_TYPE_33M; break; // jina-embeddings-small + case 12: type = LLM_TYPE_137M; break; // jina-embeddings-base + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_NOMIC_BERT: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); - ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type); ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); if (hparams.n_layer == 12 && hparams.n_embd == 768) { - model.type = e_model::MODEL_137M; + type = LLM_TYPE_137M; } } break; case LLM_ARCH_BLOOM: @@ -518,14 +682,14 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 24: model.type = e_model::MODEL_1B; break; + case 24: type = LLM_TYPE_1B; break; case 30: switch (hparams.n_embd) { - case 2560: model.type = e_model::MODEL_3B; break; - case 4096: model.type = e_model::MODEL_7B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 2560: type = LLM_TYPE_3B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; } break; - default: model.type = e_model::MODEL_UNKNOWN; + default: type = LLM_TYPE_UNKNOWN; } // TODO: become GGUF KV parameter @@ -538,9 +702,9 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_7B; break; - case 48: model.type = e_model::MODEL_30B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_7B; break; + case 48: type = LLM_TYPE_30B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_STABLELM: @@ -548,10 +712,10 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 24: model.type = e_model::MODEL_1B; break; - case 32: model.type = e_model::MODEL_3B; break; - case 40: model.type = e_model::MODEL_12B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 24: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_3B; break; + case 40: type = LLM_TYPE_12B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_QWEN: @@ -559,9 +723,9 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_7B; break; - case 40: model.type = e_model::MODEL_13B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_13B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_QWEN2VL: @@ -573,27 +737,27 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 24: model.type = hparams.n_embd == 1024 ? e_model::MODEL_0_5B : e_model::MODEL_1B; break; - case 28: model.type = hparams.n_embd == 1536 ? e_model::MODEL_1_5B : e_model::MODEL_7B; break; - case 32: model.type = e_model::MODEL_7B; break; - case 36: model.type = e_model::MODEL_3B; break; - case 40: model.type = hparams.n_head() == 20 ? e_model::MODEL_4B : e_model::MODEL_13B; break; - case 48: model.type = e_model::MODEL_14B; break; - case 64: model.type = e_model::MODEL_32B; break; - case 80: model.type = e_model::MODEL_70B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_5B : LLM_TYPE_1B; break; + case 28: type = hparams.n_embd == 1536 ? LLM_TYPE_1_5B : LLM_TYPE_7B; break; + case 32: type = LLM_TYPE_7B; break; + case 36: type = LLM_TYPE_3B; break; + case 40: type = hparams.n_head() == 20 ? LLM_TYPE_4B : LLM_TYPE_13B; break; + case 48: type = LLM_TYPE_14B; break; + case 64: type = LLM_TYPE_32B; break; + case 80: type = LLM_TYPE_70B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_QWEN2MOE: { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 24: model.type = e_model::MODEL_A2_7B; break; - case 28: model.type = e_model::MODEL_57B_A14B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 24: type = LLM_TYPE_A2_7B; break; + case 28: type = LLM_TYPE_57B_A14B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_PHI2: @@ -601,9 +765,9 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 24: model.type = e_model::MODEL_1B; break; - case 32: model.type = e_model::MODEL_3B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 24: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_PHI3: @@ -611,10 +775,10 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 24: model.type = e_model::MODEL_1B; break; - case 32: model.type = e_model::MODEL_3B; break; - case 40: model.type = e_model::MODEL_14B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 24: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_3B; break; + case 40: type = LLM_TYPE_14B; break; + default: type = LLM_TYPE_UNKNOWN; } // for backward compatibility ; see: https://github.com/ggerganov/llama.cpp/pull/8931 @@ -633,32 +797,41 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { throw std::runtime_error("invalid value for sliding_window"); } } break; + case LLM_ARCH_PHIMOE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_16x3_8B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_PLAMO: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 40: model.type = e_model::MODEL_13B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 40: type = LLM_TYPE_13B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_GPT2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 12: model.type = e_model::MODEL_SMALL; break; - case 24: model.type = e_model::MODEL_MEDIUM; break; - case 36: model.type = e_model::MODEL_LARGE; break; - case 48: model.type = e_model::MODEL_XL; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 12: type = LLM_TYPE_SMALL; break; + case 24: type = LLM_TYPE_MEDIUM; break; + case 36: type = LLM_TYPE_LARGE; break; + case 48: type = LLM_TYPE_XL; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_CODESHELL: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 42: model.type = e_model::MODEL_7B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 42: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_ORION: @@ -666,17 +839,17 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 40: model.type = e_model::MODEL_14B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 40: type = LLM_TYPE_14B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_INTERNLM2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_7B; break; - case 48: model.type = e_model::MODEL_20B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_7B; break; + case 48: type = LLM_TYPE_20B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_GEMMA: @@ -684,37 +857,37 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 18: model.type = e_model::MODEL_2B; break; - case 28: model.type = e_model::MODEL_7B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 18: type = LLM_TYPE_2B; break; + case 28: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_GEMMA2: { hparams.n_swa = 4096; // default value of gemma 2 - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false); - ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); + ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false); + ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); hparams.attn_soft_cap = true; switch (hparams.n_layer) { - case 26: model.type = e_model::MODEL_2B; break; - case 42: model.type = e_model::MODEL_9B; break; - case 46: model.type = e_model::MODEL_27B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 26: type = LLM_TYPE_2B; break; + case 42: type = LLM_TYPE_9B; break; + case 46: type = LLM_TYPE_27B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_STARCODER2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 30: model.type = e_model::MODEL_3B; break; - case 32: model.type = e_model::MODEL_7B; break; - case 40: model.type = e_model::MODEL_15B; break; - case 52: model.type = e_model::MODEL_20B; break; // granite - case 88: model.type = e_model::MODEL_34B; break; // granite - default: model.type = e_model::MODEL_UNKNOWN; + case 30: type = LLM_TYPE_3B; break; + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_15B; break; + case 52: type = LLM_TYPE_20B; break; // granite + case 88: type = LLM_TYPE_34B; break; // granite + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_MAMBA: @@ -730,51 +903,51 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { switch (hparams.n_layer) { case 24: switch (hparams.n_embd) { - case 768: model.type = e_model::MODEL_SMALL; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 768: type = LLM_TYPE_SMALL; break; + default: type = LLM_TYPE_UNKNOWN; } break; case 48: switch (hparams.n_embd) { - case 1024: model.type = e_model::MODEL_MEDIUM; break; - case 1536: model.type = e_model::MODEL_LARGE; break; - case 2048: model.type = e_model::MODEL_XL; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 1024: type = LLM_TYPE_MEDIUM; break; + case 1536: type = LLM_TYPE_LARGE; break; + case 2048: type = LLM_TYPE_XL; break; + default: type = LLM_TYPE_UNKNOWN; } break; case 64: switch (hparams.n_embd) { - case 2560: model.type = e_model::MODEL_3B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 2560: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; } break; - default: model.type = e_model::MODEL_UNKNOWN; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_XVERSE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_7B; break; - case 40: model.type = e_model::MODEL_13B; break; - case 80: model.type = e_model::MODEL_65B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_13B; break; + case 80: type = LLM_TYPE_65B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_COMMAND_R: { - ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 40: model.type = e_model::MODEL_35B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 40: type = LLM_TYPE_35B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_COHERE2: { ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); - ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_8B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_8B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_DBRX: @@ -783,8 +956,8 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv); switch (hparams.n_layer) { - case 40: model.type = e_model::MODEL_16x12B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 40: type = LLM_TYPE_16x12B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_OLMO: @@ -793,10 +966,10 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false); switch (hparams.n_layer) { - case 22: model.type = e_model::MODEL_1B; break; - case 32: model.type = e_model::MODEL_7B; break; - case 80: model.type = e_model::MODEL_70B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 22: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_7B; break; + case 80: type = LLM_TYPE_70B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_OLMO2: @@ -804,18 +977,18 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 16: model.type = e_model::MODEL_1B; break; - case 32: model.type = e_model::MODEL_7B; break; - case 40: model.type = e_model::MODEL_13B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 16: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_13B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_OLMOE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 16: model.type = e_model::MODEL_A1_7B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 16: type = LLM_TYPE_A1_7B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_OPENELM: @@ -823,57 +996,57 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 16: model.type = e_model::MODEL_270M; break; - case 20: model.type = e_model::MODEL_450M; break; - case 28: model.type = e_model::MODEL_1B; break; - case 36: model.type = e_model::MODEL_3B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 16: type = LLM_TYPE_270M; break; + case 20: type = LLM_TYPE_450M; break; + case 28: type = LLM_TYPE_1B; break; + case 36: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_GPTNEOX: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res); + ml.get_key(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res); switch (hparams.n_layer) { case 6: switch (hparams.n_ff()) { - case 512: model.type = e_model::MODEL_14M; break; - case 2048: model.type = e_model::MODEL_70M; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 512: type = LLM_TYPE_14M; break; + case 2048: type = LLM_TYPE_70M; break; + default: type = LLM_TYPE_UNKNOWN; } break; case 12: switch (hparams.n_ff()) { - case 3072: model.type = e_model::MODEL_160M; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 3072: type = LLM_TYPE_160M; break; + default: type = LLM_TYPE_UNKNOWN; } break; case 16: switch (hparams.n_ff()) { - case 8192: model.type = e_model::MODEL_1B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 8192: type = LLM_TYPE_1B; break; + default: type = LLM_TYPE_UNKNOWN; } break; case 24: switch (hparams.n_ff()) { - case 4096: model.type = e_model::MODEL_410M; break; - case 8192: model.type = e_model::MODEL_1_4B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 4096: type = LLM_TYPE_410M; break; + case 8192: type = LLM_TYPE_1_4B; break; + default: type = LLM_TYPE_UNKNOWN; } break; case 32: switch (hparams.n_ff()) { - case 10240: model.type = e_model::MODEL_2_8B; break; - case 16384: model.type = e_model::MODEL_6_9B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 10240: type = LLM_TYPE_2_8B; break; + case 16384: type = LLM_TYPE_6_9B; break; + default: type = LLM_TYPE_UNKNOWN; } break; case 36: switch (hparams.n_ff()) { - case 20480: model.type = e_model::MODEL_12B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 20480: type = LLM_TYPE_12B; break; + default: type = LLM_TYPE_UNKNOWN; } break; case 44: switch (hparams.n_ff()) { - case 24576: model.type = e_model::MODEL_20B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 24576: type = LLM_TYPE_20B; break; + default: type = LLM_TYPE_UNKNOWN; } break; - default: model.type = e_model::MODEL_UNKNOWN; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_ARCTIC: @@ -882,40 +1055,40 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { if (hparams.n_expert == 128) { switch (hparams.n_layer) { - case 35: model.type = e_model::MODEL_10B_128x3_66B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 35: type = LLM_TYPE_10B_128x3_66B; break; + default: type = LLM_TYPE_UNKNOWN; } } else { - model.type = e_model::MODEL_UNKNOWN; + type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_DEEPSEEK: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); switch (hparams.n_layer) { - case 28: model.type = e_model::MODEL_20B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 28: type = LLM_TYPE_20B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_DEEPSEEK2: { bool is_lite = (hparams.n_layer == 27); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); if (!is_lite) { ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); } - ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { // for compatibility with existing DeepSeek V2 and V2.5 GGUFs // that have no expert_gating_func model parameter set @@ -924,19 +1097,31 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul); switch (hparams.n_layer) { - case 27: model.type = e_model::MODEL_16B; break; - case 60: model.type = e_model::MODEL_236B; break; - case 61: model.type = e_model::MODEL_671B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 27: type = LLM_TYPE_16B; break; + case 60: type = LLM_TYPE_236B; break; + case 61: type = LLM_TYPE_671B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_CHATGLM: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 28: model.type = e_model::MODEL_6B; break; - case 40: model.type = e_model::MODEL_9B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 28: { + if (hparams.n_head(0) == 16) { + type = LLM_TYPE_1_5B; + } else { + type = LLM_TYPE_6B; + } + } break; + case 40: { + if (hparams.n_head(0) == 24) { + type = LLM_TYPE_4B; + } else { + type = LLM_TYPE_9B; + } + } break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_BITNET: @@ -944,13 +1129,13 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 26: model.type = e_model::MODEL_3B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 26: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_T5: { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); uint32_t dec_start_token_id; @@ -959,32 +1144,32 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { } switch (hparams.n_layer) { - case 6: model.type = e_model::MODEL_60M; break; // t5-small - case 8: model.type = e_model::MODEL_80M; break; // flan-t5-small + case 6: type = LLM_TYPE_60M; break; // t5-small + case 8: type = LLM_TYPE_80M; break; // flan-t5-small case 12: switch (hparams.n_ff()) { - case 3072: model.type = e_model::MODEL_220M; break; // t5-base - case 2048: model.type = e_model::MODEL_250M; break; // flan-t5-base - default: model.type = e_model::MODEL_UNKNOWN; + case 3072: type = LLM_TYPE_220M; break; // t5-base + case 2048: type = LLM_TYPE_250M; break; // flan-t5-base + default: type = LLM_TYPE_UNKNOWN; } break; case 24: switch (hparams.n_ff()) { - case 4096: model.type = e_model::MODEL_770M; break; // t5-large - case 2816: model.type = e_model::MODEL_780M; break; // flan-t5-large - case 16384: model.type = e_model::MODEL_3B; break; // t5-3b - case 5120: model.type = e_model::MODEL_3B; break; // flan-t5-xl - case 65536: model.type = e_model::MODEL_11B; break; // t5-11b - case 10240: model.type = e_model::MODEL_11B; break; // flan-t5-xxl - default: model.type = e_model::MODEL_UNKNOWN; + case 4096: type = LLM_TYPE_770M; break; // t5-large + case 2816: type = LLM_TYPE_780M; break; // flan-t5-large + case 16384: type = LLM_TYPE_3B; break; // t5-3b + case 5120: type = LLM_TYPE_3B; break; // flan-t5-xl + case 65536: type = LLM_TYPE_11B; break; // t5-11b + case 10240: type = LLM_TYPE_11B; break; // flan-t5-xxl + default: type = LLM_TYPE_UNKNOWN; } break; - default: model.type = e_model::MODEL_UNKNOWN; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_T5ENCODER: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); - model.type = e_model::MODEL_UNKNOWN; + type = LLM_TYPE_UNKNOWN; } break; case LLM_ARCH_JAIS: { @@ -992,18 +1177,18 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias); switch (hparams.n_layer) { - case 24: model.type = e_model::MODEL_1_3B; break; - case 40: model.type = e_model::MODEL_13B; break; + case 24: type = LLM_TYPE_1_3B; break; + case 40: type = LLM_TYPE_13B; break; /* TODO: add variants */ - default: model.type = e_model::MODEL_UNKNOWN; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_NEMOTRON: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_4B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_4B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_EXAONE: @@ -1011,44 +1196,48 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_8B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_8B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_RWKV6: + case LLM_ARCH_RWKV6QWEN2: { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); - ml.get_key(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim); - ml.get_key(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim); - ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, false); + ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); + ml.get_key(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim); + ml.get_key(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim); + ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false); + ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); switch (hparams.n_layer) { - case 24: model.type = e_model::MODEL_1_6B; break; + case 24: type = LLM_TYPE_1_6B; break; case 32: switch (hparams.n_embd) { - case 2560: model.type = e_model::MODEL_3B; break; - case 4096: model.type = e_model::MODEL_7B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 2560: type = LLM_TYPE_3B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; } break; - case 61: model.type = e_model::MODEL_14B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 61: type = LLM_TYPE_14B; break; + case 64: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); - ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); - ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); - ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); + ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_3B; break; - case 40: model.type = e_model::MODEL_3B; break; + case 32: type = LLM_TYPE_3B; break; + case 40: type = LLM_TYPE_3B; break; // Add additional layer/vocab/etc checks here for other model sizes - default: model.type = e_model::MODEL_UNKNOWN; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_CHAMELEON: @@ -1058,9 +1247,9 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_SWIN_NORM, hparams.swin_norm); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_7B; break; - case 48: model.type = e_model::MODEL_34B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_7B; break; + case 48: type = LLM_TYPE_34B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_SOLAR: @@ -1069,13 +1258,13 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { for (size_t i = 0; i < hparams.n_bskcn_arr.max_size(); ++i) { auto & bskcn = hparams.n_bskcn_arr[i]; bskcn.fill(0); - auto kv = LLM_KV(model.arch); + auto kv = LLM_KV(arch); ml.get_key_or_arr(format((kv(LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION) + ".%d").c_str(), i), bskcn, hparams.n_layer, false); } switch (hparams.n_layer) { - case 64: model.type = e_model::MODEL_22B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 64: type = LLM_TYPE_22B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_WAVTOKENIZER_DEC: @@ -1088,724 +1277,2426 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { default: throw std::runtime_error("unsupported model architecture"); } - model.ftype = ml.ftype; + pimpl->n_bytes = ml.n_bytes; + + pimpl->desc_str = arch_name() + " " + type_name() + " " + ml.ftype_name(); if (hparams.f_max_alibi_bias > 0.0f) { hparams.use_alibi = true; } - hparams.rope_type = llama_rope_type(&model); + hparams.rope_type = llama_model_rope_type(this); } -void llm_load_vocab(llama_model_loader & ml, llama_model & model) { - auto & vocab = model.vocab; +void llama_model::load_vocab(llama_model_loader & ml) { + const auto kv = LLM_KV(arch); - struct gguf_context * ctx = ml.meta.get(); + vocab.load(ml, kv); +} - const auto kv = LLM_KV(model.arch); +bool llama_model::load_tensors(llama_model_loader & ml) { + const auto & split_mode = params.split_mode; + const auto & n_gpu_layers = params.n_gpu_layers; + const auto & use_mlock = params.use_mlock; + const auto & tensor_split = params.tensor_split; - // determine vocab type - { - std::string tokenizer_model; - std::string tokenizer_pre; + const int n_layer = hparams.n_layer; - ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model); - ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false); + const bool use_mmap_buffer = true; - if (tokenizer_model == "no_vocab" || tokenizer_model == "none") { - vocab.type = LLAMA_VOCAB_TYPE_NONE; + LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s)\n", __func__, ml.use_mmap ? "true" : "false"); - // default special tokens - vocab.special_bos_id = LLAMA_TOKEN_NULL; - vocab.special_eos_id = LLAMA_TOKEN_NULL; - vocab.special_unk_id = LLAMA_TOKEN_NULL; - vocab.special_sep_id = LLAMA_TOKEN_NULL; - vocab.special_pad_id = LLAMA_TOKEN_NULL; - vocab.special_cls_id = LLAMA_TOKEN_NULL; - vocab.special_mask_id = LLAMA_TOKEN_NULL; - vocab.linefeed_id = LLAMA_TOKEN_NULL; - - // read vocab size from metadata - if (!ml.get_key(LLM_KV_VOCAB_SIZE, vocab.n_vocab, false)) { - vocab.n_vocab = 0; - LLAMA_LOG_WARN("%s: there is no vocab_size in metadata, vocab.n_vocab will be set to %u\n", __func__, vocab.n_vocab); - } - return; - } - - if (tokenizer_model == "llama") { - vocab.type = LLAMA_VOCAB_TYPE_SPM; - - // default special tokens - vocab.special_bos_id = 1; - vocab.special_eos_id = 2; - vocab.special_unk_id = 0; - vocab.special_sep_id = LLAMA_TOKEN_NULL; - vocab.special_pad_id = LLAMA_TOKEN_NULL; - vocab.special_cls_id = LLAMA_TOKEN_NULL; - vocab.special_mask_id = LLAMA_TOKEN_NULL; - } else if (tokenizer_model == "bert") { - vocab.type = LLAMA_VOCAB_TYPE_WPM; - - // default special tokens - vocab.special_bos_id = LLAMA_TOKEN_NULL; - vocab.special_eos_id = LLAMA_TOKEN_NULL; - vocab.special_unk_id = 100; - vocab.special_sep_id = 102; - vocab.special_pad_id = 0; - vocab.special_cls_id = 101; - vocab.special_mask_id = 103; - } else if (tokenizer_model == "gpt2") { - vocab.type = LLAMA_VOCAB_TYPE_BPE; - - // read bpe merges and populate bpe ranks - const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str()); - if (merges_keyidx == -1) { - throw std::runtime_error("cannot find tokenizer merges in model file\n"); - } - - const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); - for (int i = 0; i < n_merges; i++) { - const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); - GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); - - std::string first; - std::string second; - - const size_t pos = word.find(' ', 1); - - if (pos != std::string::npos) { - first = word.substr(0, pos); - second = word.substr(pos + 1); - } - - vocab.bpe_ranks.emplace(std::make_pair(first, second), i); - } - - // default special tokens - vocab.special_bos_id = 11; - vocab.special_eos_id = 11; - vocab.special_unk_id = LLAMA_TOKEN_NULL; - vocab.special_sep_id = LLAMA_TOKEN_NULL; - vocab.special_pad_id = LLAMA_TOKEN_NULL; - vocab.special_cls_id = LLAMA_TOKEN_NULL; - vocab.special_mask_id = LLAMA_TOKEN_NULL; - } else if (tokenizer_model == "t5") { - vocab.type = LLAMA_VOCAB_TYPE_UGM; - - // default special tokens - vocab.special_bos_id = LLAMA_TOKEN_NULL; - vocab.special_eos_id = 1; - vocab.special_unk_id = 2; - vocab.special_sep_id = LLAMA_TOKEN_NULL; - vocab.special_pad_id = 0; - vocab.special_cls_id = LLAMA_TOKEN_NULL; - vocab.special_mask_id = LLAMA_TOKEN_NULL; - - const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str()); - if (precompiled_charsmap_keyidx != -1) { - size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx); - const char * precompiled_charsmap = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx); - vocab.precompiled_charsmap.assign(precompiled_charsmap, precompiled_charsmap + n_precompiled_charsmap); -#ifdef IS_BIG_ENDIAN - // correct endiannes of data in precompiled_charsmap binary blob - uint32_t * xcda_blob_size = (uint32_t *) &vocab.precompiled_charsmap[0]; - *xcda_blob_size = __builtin_bswap32(*xcda_blob_size); - assert(*xcda_blob_size + sizeof(uint32_t) < n_precompiled_charsmap); - size_t xcda_array_size = *xcda_blob_size / sizeof(uint32_t); - uint32_t * xcda_array = (uint32_t *) &vocab.precompiled_charsmap[sizeof(uint32_t)]; - for (size_t i = 0; i < xcda_array_size; ++i) { - xcda_array[i] = __builtin_bswap32(xcda_array[i]); - } -#endif - } - } else if (tokenizer_model == "rwkv") { - vocab.type = LLAMA_VOCAB_TYPE_RWKV; - - // default special tokens - vocab.special_bos_id = LLAMA_TOKEN_NULL; - vocab.special_eos_id = LLAMA_TOKEN_NULL; - vocab.special_unk_id = LLAMA_TOKEN_NULL; - vocab.special_sep_id = LLAMA_TOKEN_NULL; - vocab.special_pad_id = LLAMA_TOKEN_NULL; - } else { - throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str())); - } - - // for now, only BPE models have pre-tokenizers - if (vocab.type == LLAMA_VOCAB_TYPE_BPE) { - vocab.tokenizer_add_space_prefix = false; - vocab.tokenizer_clean_spaces = true; - if (tokenizer_pre == "default") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; - } else if ( - tokenizer_pre == "llama3" || - tokenizer_pre == "llama-v3" || - tokenizer_pre == "llama-bpe"|| - tokenizer_pre == "falcon3") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_LLAMA3; - vocab.tokenizer_ignore_merges = true; - vocab.tokenizer_add_bos = true; - } else if ( - tokenizer_pre == "deepseek-llm") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM; - vocab.tokenizer_clean_spaces = false; - } else if ( - tokenizer_pre == "deepseek-coder") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER; - vocab.tokenizer_clean_spaces = false; - } else if ( - tokenizer_pre == "deepseek-v3") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM; - vocab.tokenizer_clean_spaces = false; - } else if ( - tokenizer_pre == "falcon") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_FALCON; - } else if ( - tokenizer_pre == "mpt") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_MPT; - } else if ( - tokenizer_pre == "starcoder") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_STARCODER; - } else if ( - tokenizer_pre == "gpt-2" || - tokenizer_pre == "phi-2" || - tokenizer_pre == "jina-es" || - tokenizer_pre == "jina-de" || - tokenizer_pre == "gigachat" || - tokenizer_pre == "jina-v1-en" || - tokenizer_pre == "jina-v2-es" || - tokenizer_pre == "jina-v2-de" || - tokenizer_pre == "jina-v2-code" || - tokenizer_pre == "roberta-bpe") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_GPT2; - } else if ( - tokenizer_pre == "refact") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_REFACT; - } else if ( - tokenizer_pre == "command-r") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_COMMAND_R; - vocab.tokenizer_clean_spaces = false; - } else if ( - tokenizer_pre == "qwen2") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_QWEN2; - vocab.tokenizer_clean_spaces = false; - } else if ( - tokenizer_pre == "stablelm2") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_STABLELM2; - } else if ( - tokenizer_pre == "olmo") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_OLMO; - } else if ( - tokenizer_pre == "dbrx") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DBRX; - } else if ( - tokenizer_pre == "smaug-bpe") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMAUG; - } else if ( - tokenizer_pre == "poro-chat") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_PORO; - vocab.tokenizer_clean_spaces = false; - } else if ( - tokenizer_pre == "chatglm-bpe") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHATGLM4; - vocab.special_bos_id = LLAMA_TOKEN_NULL; - } else if ( - tokenizer_pre == "viking") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_VIKING; - vocab.tokenizer_clean_spaces = false; - } else if ( - tokenizer_pre == "jais") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_JAIS; - } else if ( - tokenizer_pre == "tekken") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_TEKKEN; - vocab.tokenizer_clean_spaces = false; - vocab.tokenizer_ignore_merges = true; - vocab.tokenizer_add_bos = true; - } else if ( - tokenizer_pre == "smollm") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMOLLM; - vocab.tokenizer_clean_spaces = false; - } else if ( - tokenizer_pre == "codeshell") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CODESHELL; - } else if ( - tokenizer_pre == "bloom") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_BLOOM; - } else if ( - tokenizer_pre == "gpt3-finnish") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH; - } else if ( - tokenizer_pre == "exaone") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_EXAONE; - } else if ( - tokenizer_pre == "chameleon") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHAMELEON; - vocab.tokenizer_add_bos = true; - vocab.tokenizer_clean_spaces = false; - } else if ( - tokenizer_pre == "minerva-7b") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_MINERVA; - } else if ( - tokenizer_pre == "megrez") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_QWEN2; - } else { - LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__); - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; - } - } else if (vocab.type == LLAMA_VOCAB_TYPE_SPM) { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; - vocab.tokenizer_add_space_prefix = true; - vocab.tokenizer_clean_spaces = false; - vocab.tokenizer_add_bos = true; - vocab.tokenizer_add_eos = false; - } else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; - vocab.tokenizer_add_space_prefix = false; - vocab.tokenizer_clean_spaces = true; - vocab.tokenizer_add_bos = true; - vocab.tokenizer_add_eos = false; - } else if (vocab.type == LLAMA_VOCAB_TYPE_UGM) { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; - vocab.tokenizer_add_bos = false; - vocab.tokenizer_add_eos = true; - } else if (vocab.type == LLAMA_VOCAB_TYPE_RWKV) { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; - vocab.tokenizer_add_space_prefix = false; - vocab.tokenizer_clean_spaces = false; - vocab.tokenizer_add_bos = false; - vocab.tokenizer_add_eos = false; - } else { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; - } - - ml.get_key(LLM_KV_TOKENIZER_ADD_PREFIX, vocab.tokenizer_add_space_prefix, false); - ml.get_key(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, vocab.tokenizer_remove_extra_whitespaces, false); + // build a list of buffer types for the CPU and GPU devices + pimpl->cpu_buft_list = make_cpu_buft_list(devices); + for (auto * dev : devices) { + buft_list_t buft_list = make_gpu_buft_list(dev, split_mode, tensor_split); + // add CPU buffer types as a fallback + buft_list.insert(buft_list.end(), pimpl->cpu_buft_list.begin(), pimpl->cpu_buft_list.end()); + pimpl->gpu_buft_list.emplace(dev, std::move(buft_list)); } - const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str()); - if (token_idx == -1) { - throw std::runtime_error("cannot find tokenizer vocab in model file\n"); - } - - const float * scores = nullptr; - const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str()); - if (score_idx != -1) { - scores = (const float * ) gguf_get_arr_data(ctx, score_idx); - } - - const int * toktypes = nullptr; - const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str()); - if (toktype_idx != -1) { - toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx); - } - - const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx); - - vocab.n_vocab = n_vocab; - vocab.id_to_token.resize(n_vocab); - - for (uint32_t i = 0; i < n_vocab; i++) { - std::string word = gguf_get_arr_str(ctx, token_idx, i); - if (word.empty()) { - LLAMA_LOG_WARN("%s: empty token at index %u\n", __func__, i); - word = "[EMPTY_" + std::to_string(i) + "]"; + // calculate the split points + bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + n_devices(), [](float x) { return x == 0.0f; }); + std::vector splits(n_devices()); + if (all_zero) { + // default split, by free memory + for (size_t i = 0; i < n_devices(); ++i) { + ggml_backend_dev_t dev = devices[i]; + size_t total; + size_t free; + ggml_backend_dev_memory(dev, &free, &total); + splits[i] = free; } - - vocab.token_to_id[word] = i; - vocab.max_token_len = std::max(vocab.max_token_len, (int) word.size()); - - auto & token_data = vocab.id_to_token[i]; - token_data.text = std::move(word); - token_data.score = scores ? scores[i] : 0.0f; - token_data.attr = LLAMA_TOKEN_ATTR_NORMAL; - - if (toktypes) { //TODO: remove, required until per token attributes are available from GGUF file - switch(toktypes[i]) { - case LLAMA_TOKEN_TYPE_UNKNOWN: token_data.attr = LLAMA_TOKEN_ATTR_UNKNOWN; break; - case LLAMA_TOKEN_TYPE_UNUSED: token_data.attr = LLAMA_TOKEN_ATTR_UNUSED; break; - case LLAMA_TOKEN_TYPE_NORMAL: token_data.attr = LLAMA_TOKEN_ATTR_NORMAL; break; - case LLAMA_TOKEN_TYPE_CONTROL: token_data.attr = LLAMA_TOKEN_ATTR_CONTROL; break; - case LLAMA_TOKEN_TYPE_USER_DEFINED: token_data.attr = LLAMA_TOKEN_ATTR_USER_DEFINED; break; - case LLAMA_TOKEN_TYPE_BYTE: token_data.attr = LLAMA_TOKEN_ATTR_BYTE; break; - case LLAMA_TOKEN_TYPE_UNDEFINED: token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED; break; - default: token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED; break; - } - } - } - GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size()); - - vocab.init_tokenizer(); - - // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n' - if (vocab.type == LLAMA_VOCAB_TYPE_SPM) { - try { - vocab.linefeed_id = llama_byte_to_token_impl(vocab, '\n'); - } catch (const std::exception & e) { - LLAMA_LOG_WARN("%s: SPM vocabulary, but newline token not found: %s! Using special_pad_id instead.", __func__, e.what()); - vocab.linefeed_id = vocab.special_pad_id; - } - } else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) { - vocab.linefeed_id = vocab.special_pad_id; - } else if (vocab.type == LLAMA_VOCAB_TYPE_RWKV) { - const std::vector ids = llama_tokenize_internal(vocab, "\n", false); - GGML_ASSERT(!ids.empty() && "model vocab missing newline token"); - vocab.linefeed_id = ids[0]; } else { - const std::vector ids = llama_tokenize_internal(vocab, "\xC4\x8A", false); // U+010A - - //GGML_ASSERT(!ids.empty() && "model vocab missing newline token"); - if (ids.empty()) { - LLAMA_LOG_WARN("%s: model vocab missing newline token, using special_pad_id instead\n", __func__); - vocab.linefeed_id = vocab.special_pad_id; - } else { - vocab.linefeed_id = ids[0]; - } + std::copy(tensor_split, tensor_split + n_devices(), splits.begin()); } - // special tokens - { - const std::vector> special_token_types = { - { LLM_KV_TOKENIZER_BOS_ID, vocab.special_bos_id }, - { LLM_KV_TOKENIZER_EOS_ID, vocab.special_eos_id }, - { LLM_KV_TOKENIZER_EOT_ID, vocab.special_eot_id }, - { LLM_KV_TOKENIZER_EOM_ID, vocab.special_eom_id }, - { LLM_KV_TOKENIZER_UNK_ID, vocab.special_unk_id }, - { LLM_KV_TOKENIZER_SEP_ID, vocab.special_sep_id }, - { LLM_KV_TOKENIZER_PAD_ID, vocab.special_pad_id }, - { LLM_KV_TOKENIZER_CLS_ID, vocab.special_cls_id }, - { LLM_KV_TOKENIZER_MASK_ID, vocab.special_mask_id }, - { LLM_KV_TOKENIZER_FIM_PRE_ID, vocab.special_fim_pre_id }, - { LLM_KV_TOKENIZER_FIM_SUF_ID, vocab.special_fim_suf_id }, - { LLM_KV_TOKENIZER_FIM_MID_ID, vocab.special_fim_mid_id }, - { LLM_KV_TOKENIZER_FIM_PAD_ID, vocab.special_fim_pad_id }, - { LLM_KV_TOKENIZER_FIM_REP_ID, vocab.special_fim_rep_id }, - { LLM_KV_TOKENIZER_FIM_SEP_ID, vocab.special_fim_sep_id }, + // sum and normalize the splits to get the split points + float split_sum = 0.0f; + for (size_t i = 0; i < n_devices(); ++i) { + split_sum += splits[i]; + splits[i] = split_sum; + } + for (size_t i = 0; i < n_devices(); ++i) { + splits[i] /= split_sum; + } - // deprecated - { LLM_KV_TOKENIZER_PREFIX_ID, vocab.special_fim_pre_id }, - { LLM_KV_TOKENIZER_SUFFIX_ID, vocab.special_fim_suf_id }, - { LLM_KV_TOKENIZER_MIDDLE_ID, vocab.special_fim_mid_id }, + ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0); + const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, (int)n_layer + 1); + auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev { + if (il < i_gpu_start || (il - i_gpu_start) >= act_gpu_layers) { + LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s\n", il, ggml_backend_dev_name(cpu_dev)); + return {cpu_dev, &pimpl->cpu_buft_list}; + } + const int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + n_devices(), float(il - i_gpu_start)/act_gpu_layers) - splits.begin(); + auto * dev = devices.at(layer_gpu); + LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s\n", il, ggml_backend_dev_name(dev)); + return {dev, &pimpl->gpu_buft_list.at(dev)}; + }; + + // assign the input layer + // there is very little benefit to offloading the input layer, so always keep it on the CPU + pimpl->dev_input = { cpu_dev, &pimpl->cpu_buft_list }; + + // assign the repeating layers to the devices according to the splits + pimpl->dev_layer.resize(n_layer); + for (int il = 0; il < n_layer; ++il) { + pimpl->dev_layer[il] = get_layer_buft_list(il); + } + + // assign the output layer + pimpl->dev_output = get_layer_buft_list(n_layer); + + // one ggml context per buffer type + int max_n_tensors = ml.n_tensors; + max_n_tensors += 1; // duplicated output tensor + max_n_tensors += n_layer*2; // duplicated rope freq tensors + const size_t ctx_size = ggml_tensor_overhead()*max_n_tensors; + + std::map ctx_map; + auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + ggml_init_params params = { + /*.mem_size =*/ ctx_size, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + if (!ctx) { + throw std::runtime_error(format("failed to create ggml context")); + } + + ctx_map[buft] = ctx; + pimpl->ctxs.emplace_back(ctx); + + return ctx; + } + return it->second; + }; + + const auto TENSOR_DUPLICATED = llama_model_loader::TENSOR_DUPLICATED; + const auto TENSOR_NOT_REQUIRED = llama_model_loader::TENSOR_NOT_REQUIRED; + + // create tensors for the weights + { + // note: cast to int64_t since we will use these for the tensor dimensions + const int64_t n_head = hparams.n_head(); + const int64_t n_head_kv = hparams.n_head_kv(); + const int64_t n_embd = hparams.n_embd; + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); + const int64_t n_embd_head_k = hparams.n_embd_head_k; + const int64_t n_embd_head_v = hparams.n_embd_head_v; + const int64_t n_ff = hparams.n_ff(); + const int64_t n_embd_gqa = n_embd_v_gqa; + const int64_t n_vocab = hparams.n_vocab; + const int64_t n_token_types = vocab.n_token_types(); + const int64_t n_rot = hparams.n_rot; + const int64_t n_expert = hparams.n_expert; + const int64_t n_expert_used = hparams.n_expert_used; + const int64_t n_ctx_train = hparams.n_ctx_train; + + if (n_expert > 0 && hparams.n_expert_used == 0) { + throw std::runtime_error("model has expert layers but no expert layers are used"); + } + + int n_moved_tensors = 0; + ggml_tensor * first_moved_tensor = nullptr; + ggml_backend_buffer_type_t first_moved_from_buft = nullptr; + ggml_backend_buffer_type_t first_moved_to_buft = nullptr; + + auto create_tensor = [&](const LLM_TN_IMPL & tn, const std::initializer_list & ne, int flags) -> ggml_tensor * { + ggml_tensor * t_meta = ml.get_tensor_meta(tn.str().c_str()); + + if (!t_meta) { + if (flags & TENSOR_NOT_REQUIRED) { + return nullptr; + } + throw std::runtime_error(format("missing tensor '%s'", tn.str().c_str())); + } + + // some models use the token embedding tensor as the output, but since these are used in different layers and with different ops + // the tensor is duplicated + // to handle this, we check if the tensor is duplicated, and if so, we assume that it is being loaded as the output tensor + llm_tensor tn_tensor = tn.tensor; + if (tn.tensor == LLM_TENSOR_TOKEN_EMBD && flags & TENSOR_DUPLICATED) { + tn_tensor = LLM_TENSOR_OUTPUT; + } + + llm_tensor_info info; + try { + info = llm_tensor_info_for(tn_tensor); + } catch (const std::out_of_range & e) { + throw std::runtime_error(format("missing tensor info mapping for %s", tn.str().c_str())); + } + + // skip unused tensors + if (info.op == GGML_OP_NONE) { + LLAMA_LOG_WARN("model has unused tensor %s -- ignoring\n", tn.str().c_str()); + ml.n_created++; + + return nullptr; + } + + // tensors with "bias" suffix are always used with GGML_OP_ADD + ggml_op op; + bool bias = tn.suffix != nullptr && strcmp(tn.suffix, "bias") == 0; + if (bias) { + op = GGML_OP_ADD; + } else { + op = info.op; + } + + // sanity checks + if (info.layer == LLM_TENSOR_LAYER_INPUT || info.layer == LLM_TENSOR_LAYER_OUTPUT) { + if (tn.bid != -1) { + GGML_ABORT("input/output layer tensor %s used with a layer number", tn.str().c_str()); + } + } else { + if (tn.bid == -1) { + GGML_ABORT("repeating layer tensor %s used without a layer number", tn.str().c_str()); + } + } + + // select the buffer type for this tensor + buft_list_t * buft_list; + switch (info.layer) { + case LLM_TENSOR_LAYER_INPUT: + buft_list = pimpl->dev_input.buft_list; + break; + case LLM_TENSOR_LAYER_OUTPUT: + buft_list = pimpl->dev_output.buft_list; + break; + case LLM_TENSOR_LAYER_REPEATING: + buft_list = pimpl->dev_layer.at(tn.bid).buft_list; + break; + default: + GGML_ABORT("invalid layer %d for tensor %s", info.layer, tn.str().c_str()); + } + + ggml_backend_buffer_type_t buft = select_weight_buft(hparams, t_meta, op, *buft_list); + if (!buft) { + throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str())); + } + + // avoid using a host buffer when using mmap + auto * buft_dev = ggml_backend_buft_get_device(buft); + if (ml.use_mmap && buft_dev && buft == ggml_backend_dev_host_buffer_type(buft_dev)) { + auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + buft = ggml_backend_dev_buffer_type(cpu_dev); + } + + if (buft != buft_list->front().second) { + n_moved_tensors++; + if (!first_moved_tensor) { + first_moved_tensor = t_meta; + first_moved_from_buft = buft_list->front().second; + first_moved_to_buft = buft; + } + } + + ggml_context * ctx = ctx_for_buft(buft); + + // if duplicated, check if the original tensor was allocated in the same buffer type context and avoid creating a new one + if (flags & TENSOR_DUPLICATED) { + ggml_tensor * t = ggml_get_tensor(ctx, tn.str().c_str()); + if (t) { + return t; + } + } + return ml.create_tensor(ctx, tn, ne, flags); }; - for (const auto & it : special_token_types) { - const std::string & key = kv(std::get<0>(it)); - int32_t & id = std::get<1>(it); + layers.resize(n_layer); - uint32_t new_id; - if (!ml.get_key(std::get<0>(it), new_id, false)) { - continue; - } - if (new_id >= vocab.id_to_token.size()) { - LLAMA_LOG_WARN("%s: bad special token: '%s' = %ud, using default id %d\n", - __func__, key.c_str(), new_id, id); - } else { - id = new_id; - } - } + // TODO: move to a separate function + const auto tn = LLM_TN(arch); + switch (arch) { + case LLM_ARCH_LLAMA: + case LLM_ARCH_REFACT: + case LLM_ARCH_MINICPM: + case LLM_ARCH_GRANITE: + case LLM_ARCH_GRANITE_MOE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - // Handle add_bos_token and add_eos_token - { - bool temp = true; + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - if (ml.get_key(LLM_KV_TOKENIZER_ADD_BOS, temp, false)) { - vocab.tokenizer_add_bos = temp; - } - if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) { - vocab.tokenizer_add_eos = temp; - } - } + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } - // auto-detect special tokens by text - // TODO: convert scripts should provide these tokens through the KV metadata LLM_KV_TOKENIZER_... - // for now, we apply this workaround to find the tokens based on their text + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; - for (const auto & t : vocab.token_to_id) { - // find EOT token: "<|eot_id|>", "<|im_end|>", "", etc. - if (vocab.special_eot_id == LLAMA_TOKEN_NULL) { - if (false - || t.first == "<|eot_id|>" - || t.first == "<|im_end|>" - || t.first == "<|end|>" - || t.first == "" - || t.first == "<|endoftext|>" - || t.first == "" - || t.first == "<|end▁of▁sentence|>" // DeepSeek - ) { - vocab.special_eot_id = t.second; - if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { - LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", - __func__, t.second, t.first.c_str()); - vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // optional bias tensors + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + if (n_expert == 0) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // optional MLP bias + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + } + } + } break; + case LLM_ARCH_MLLAMA: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab+8}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + if (hparams.cross_attention_layers(i)) { + layer.cross_attn_k_norm = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_K_NORM, "weight", i), {128}, 0); + layer.cross_attn_k_proj = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_K_PROJ, "weight", i), {n_embd, 1024}, 0); + layer.cross_attn_o_proj = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_O_PROJ, "weight", i), {n_embd, n_embd}, 0); + layer.cross_attn_q_norm = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_Q_NORM, "weight", i), {128}, 0); + layer.cross_attn_q_proj = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_Q_PROJ, "weight", i), {n_embd, n_embd}, 0); + layer.cross_attn_v_proj = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_V_PROJ, "weight", i), {n_embd, 1024}, 0); + layer.cross_attn_attn_gate = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_ATTN_GATE, i), {1}, 0); + layer.cross_attn_mlp_gate = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_MLP_GATE, i), {1}, 0); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + } else { + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } + } break; + case LLM_ARCH_DECI: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(i); + const int64_t n_ff = hparams.n_ff(i); + const int64_t n_head = hparams.n_head(i); + const int64_t n_head_kv = hparams.n_head_kv(i); + + if (n_head_kv == 0 && n_head > 0) { + // linear attention for DeciLMCausalModel + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + } + else if (n_head_kv > 0) { + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + } + + // optional bias tensors + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // optional MLP bias + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } + } break; + case LLM_ARCH_MINICPM3: + { + const int64_t n_embd_head_qk_rope = hparams.n_rot; + const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0); + + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); + + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}, 0); + + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0); + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head_qk_rope/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head_qk_rope/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + } break; + case LLM_ARCH_GROK: + { + if (n_expert == 0) { + throw std::runtime_error("Grok model cannot have zero experts"); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); + } + } break; + case LLM_ARCH_DBRX: + { + if (n_expert == 0) { + throw std::runtime_error("DBRX model cannot have zero experts"); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + } + } break; + case LLM_ARCH_BAICHUAN: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_FALCON: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // needs to be on GPU + } + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_STARCODER: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + // needs to be on GPU + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } + } break; + case LLM_ARCH_BERT: + case LLM_ARCH_NOMIC_BERT: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0); + + if (arch == LLM_ARCH_BERT) { + pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); + + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); + cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); + + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, 1}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {1}, TENSOR_NOT_REQUIRED); + } + + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + if (arch == LLM_ARCH_BERT) { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); + + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); + } else { + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + } + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); + layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + + if (arch == LLM_ARCH_BERT) { + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + } else { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + } + + layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); + layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0); + } + } break; + case LLM_ARCH_JINA_BERT_V2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // word_embeddings + type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0); // token_type_embeddings + + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); // LayerNorm + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); //LayerNorm bias + + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, 1}, TENSOR_NOT_REQUIRED); + cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {1}, TENSOR_NOT_REQUIRED); + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; // JinaBertLayer + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); //output_dens + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); //output_dens + + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); //output_norm + layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); + + layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); + layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0); + } + } break; + case LLM_ARCH_BLOOM: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } + } break; + case LLM_ARCH_MPT: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, TENSOR_NOT_REQUIRED); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // needs to be on GPU + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + // AWQ ScaleActivation layer + layer.ffn_act = create_tensor(tn(LLM_TENSOR_FFN_ACT, "scales", i), {n_ff}, TENSOR_NOT_REQUIRED); + } + } break; + case LLM_ARCH_STABLELM: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + // optional bias tensors, present in Stable LM 2 1.6B + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + + // optional q and k layernorms, present in StableLM 2 12B + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, TENSOR_NOT_REQUIRED); + + // optional FFN norm, not present in StableLM 2 12B which uses parallel residual + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_QWEN: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd*3}, 0); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd*3}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff/2}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff/2, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff/2}, 0); + } + } break; + case LLM_ARCH_QWEN2: + case LLM_ARCH_QWEN2VL: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + // optional bias tensors + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_QWEN2MOE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + // optional bias tensors + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0 for QWEN2MOE"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0 for QWEN2MOE"); + } + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert branch + const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; + + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd}, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); + } + } break; + case LLM_ARCH_PHI2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + + if (layer.wqkv == nullptr) { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); + + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); + } + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } + } break; + case LLM_ARCH_PHI3: + { + const int64_t n_embd_head = n_embd / n_head; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, TENSOR_NOT_REQUIRED); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }, 0); + + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + } break; + case LLM_ARCH_PHIMOE: + { + const int64_t n_embd_head = n_embd / n_head; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), { n_vocab }, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), { n_embd }, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, llama_model_loader::TENSOR_NOT_REQUIRED); + if (layer.wqkv == nullptr) { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); + + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); + } + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), { n_embd }, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + } break; + case LLM_ARCH_PLAMO: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_GPT2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } + } break; + case LLM_ARCH_CODESHELL: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } + } break; + case LLM_ARCH_ORION: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_INTERNLM2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + // layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_GEMMA: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + } + } break; + case LLM_ARCH_GEMMA2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } + } break; + case LLM_ARCH_STARCODER2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + // optional bias tensors + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // optional bias tensors + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP , "bias", i), { n_ff}, 0); + } + } break; + case LLM_ARCH_MAMBA: + { + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t dt_rank = hparams.ssm_dt_rank; + + // only an expansion factor of 2 is supported for now + if (2 * n_embd != d_inner) { + throw std::runtime_error("only an expansion factor of 2 is supported for now"); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}, 0); + + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}, 0); + + layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}, 0); + + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner}, 0); + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}, 0); + + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {d_inner}, 0); + + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); + } + } break; + case LLM_ARCH_XVERSE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_COMMAND_R: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // init output from the input tok embed + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (n_layer >= 64){ + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0); + } + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; + case LLM_ARCH_COHERE2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + // init output from the input tok embed + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, + TENSOR_DUPLICATED); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); } } - } + break; + case LLM_ARCH_OLMO: // adapted from LLM_ARCH_LLAMA with norm params removed + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - // find EOM token: "<|eom_id|>" - if (vocab.special_eom_id == LLAMA_TOKEN_NULL) { - if (false - || t.first == "<|eom_id|>" - ) { - vocab.special_eom_id = t.second; - if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { - LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", - __func__, t.second, t.first.c_str()); - vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + // output + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } - } - } - // find FIM_PRE token: "<|fim_prefix|>", "", "
", etc.
-            if (vocab.special_fim_pre_id == LLAMA_TOKEN_NULL) {
-                if (false
-                        || t.first == "<|fim_prefix|>"  // Qwen
-                        || t.first == ""
-                        || t.first == "<|fim▁begin|>" // DeepSeek
-                        || t.first == "
"
-                        ) {
-                    vocab.special_fim_pre_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                    }
-                }
-            }
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
 
-            // find FIM_SUF token: "<|fim_suffix|>", "", "", etc.
-            if (vocab.special_fim_suf_id == LLAMA_TOKEN_NULL) {
-                if (false
-                        || t.first == "<|fim_suffix|>" // Qwen
-                        || t.first == ""
-                        || t.first == "<|fim▁hole|>" // DeepSeek
-                        || t.first == ""
-                        ) {
-                    vocab.special_fim_suf_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                    }
-                }
-            }
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
 
-            // find FIM_MID token: "<|fim_middle|>", "", "", etc.
-            if (vocab.special_fim_mid_id == LLAMA_TOKEN_NULL) {
-                if (false
-                        || t.first == "<|fim_middle|>" // Qwen
-                        || t.first == ""
-                        || t.first == "<|fim▁end|>"  // DeepSeek
-                        || t.first == ""
-                        ) {
-                    vocab.special_fim_mid_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
                     }
-                }
-            }
+                } break;
+            case LLM_ARCH_OLMO2:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
-            // find FIM_PAD token: "<|fim_pad|>", "", "", etc.
-            if (vocab.special_fim_pad_id == LLAMA_TOKEN_NULL) {
-                if (false
-                        || t.first == "<|fim_pad|>" // Qwen
-                        || t.first == ""
-                        || t.first == ""
-                        ) {
-                    vocab.special_fim_pad_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                    }
-                }
-            }
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
 
-            // find FIM_REP token: "<|fim_repo|>", "", "", etc.
-            if (vocab.special_fim_rep_id == LLAMA_TOKEN_NULL) {
-                if (false
-                        || t.first == "<|fim_repo|>"  // Qwen
-                        || t.first == "<|repo_name|>"
-                        || t.first == ""
-                        || t.first == ""
-                        ) {
-                    vocab.special_fim_rep_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                    }
-                }
-            }
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
 
-            // find FIM_SEP token: "<|file_sep|>"
-            if (vocab.special_fim_sep_id == LLAMA_TOKEN_NULL) {
-                if (false
-                        || t.first == "<|file_sep|>" // Qwen
-                        ) {
-                    vocab.special_fim_sep_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
                     }
-                }
-            }
+                } break;
+            case LLM_ARCH_OLMOE:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+
+                        if (n_expert == 0) {
+                            throw std::runtime_error("n_expert must be > 0");
+                        }
+                        if (n_expert_used == 0) {
+                            throw std::runtime_error("n_expert_used must be > 0");
+                        }
+
+                        // MoE branch
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff,   n_expert}, 0);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff,   n_embd, n_expert}, 0);
+                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff,   n_expert}, 0);
+                    }
+                } break;
+            case LLM_ARCH_OPENELM:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    // init output from the input tok embed
+                    output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        const int64_t n_head      =   hparams.n_head(i);
+                        const int64_t n_head_qkv  = 2*hparams.n_head_kv(i) + n_head;
+                        const int64_t n_ff        =   hparams.n_ff(i);
+
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_head_qkv*n_embd_head_k}, 0);
+                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
+                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head*n_embd_head_k, n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_GPTNEOX:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
+                    output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
+
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, 0);
+
+                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.bo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, 0);
+
+                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
+
+                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
+
+                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_ARCTIC:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_embd}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_embd, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_embd}, 0);
+
+                        layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+                        layer.ffn_norm_exps = create_tensor(tn(LLM_TENSOR_FFN_NORM_EXPS, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd,   n_ff, n_expert}, false);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff, n_embd, n_expert}, 0);
+                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff, n_expert}, 0);
+                    }
+                } break;
+            case LLM_ARCH_DEEPSEEK:
+                {
+
+                    const int64_t n_ff_exp        = hparams.n_ff_exp;
+                    const int64_t n_expert_shared = hparams.n_expert_shared;
+
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        if (i < (int) hparams.n_layer_dense_lead) {
+                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                        } else {
+                            layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+
+                            if (n_expert == 0) {
+                                throw std::runtime_error("n_expert must be > 0");
+                            }
+                            if (n_expert_used == 0) {
+                                throw std::runtime_error("n_expert_used must be > 0");
+                            }
+
+                            // MoE branch
+                            layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+                            layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, 0);
+                            layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+
+                            // Shared expert branch
+                            layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
+                            layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {        n_ff_exp * n_expert_shared, n_embd}, 0);
+                            layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
+                        }
+                    }
+                } break;
+            case LLM_ARCH_DEEPSEEK2:
+                {
+                    const bool is_lite = (hparams.n_layer == 27);
+
+                    const int64_t n_embd_head_qk_rope = hparams.n_rot;
+                    const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
+
+                    const int64_t q_lora_rank  = hparams.n_lora_q;
+                    const int64_t kv_lora_rank = hparams.n_lora_kv;
+
+                    const int64_t n_ff_exp        = hparams.n_ff_exp;
+                    const int64_t n_expert_shared = hparams.n_expert_shared;
+
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        if (!is_lite) {
+                            layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0);
+                        }
+
+                        layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0);
+
+                        if (!is_lite) {
+                            layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0);
+                            layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}, 0);
+                        } else {
+                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        }
+
+                        layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0);
+                        layer.wkv_b     = create_tensor(tn(LLM_TENSOR_ATTN_KV_B,     "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0);
+                        layer.wo        = create_tensor(tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {              n_head * (                      n_embd_head_v), n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        if (i < (int) hparams.n_layer_dense_lead) {
+                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                        } else {
+                            layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+                            layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED);
+
+                            if (n_expert == 0) {
+                                throw std::runtime_error("n_expert must be > 0");
+                            }
+                            if (n_expert_used == 0) {
+                                throw std::runtime_error("n_expert_used must be > 0");
+                            }
+
+                            // MoE branch
+                            layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+                            layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, 0);
+                            layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+
+                            // Shared expert branch
+                            layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
+                            layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {        n_ff_exp * n_expert_shared, n_embd}, 0);
+                            layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
+                        }
+                    }
+                } break;
+            case LLM_ARCH_BITNET:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm     = create_tensor(tn(LLM_TENSOR_ATTN_NORM,     "weight", i), {n_embd}, 0);
+                        layer.attn_sub_norm = create_tensor(tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq       = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wq_scale = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "scale",  i), {1}, TENSOR_NOT_REQUIRED);
+                        layer.wk       = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wk_scale = create_tensor(tn(LLM_TENSOR_ATTN_K,   "scale",  i), {1}, TENSOR_NOT_REQUIRED);
+                        layer.wv       = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv_scale = create_tensor(tn(LLM_TENSOR_ATTN_V,   "scale",  i), {1}, TENSOR_NOT_REQUIRED);
+                        layer.wo       = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.wo_scale = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale",  i), {1}, TENSOR_NOT_REQUIRED);
+
+                        layer.ffn_norm     = create_tensor(tn(LLM_TENSOR_FFN_NORM,     "weight", i), {n_embd}, 0);
+                        layer.ffn_sub_norm = create_tensor(tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}, 0);
+
+                        layer.ffn_gate       = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_gate_scale = create_tensor(tn(LLM_TENSOR_FFN_GATE, "scale",  i), {1}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_down       = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_down_scale = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "scale",  i), {1}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_up         = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_up_scale   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "scale",  i), {1}, TENSOR_NOT_REQUIRED);
+                    }
+                } break;
+            case LLM_ARCH_T5:
+                {
+                    const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts;
+
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output_norm     = create_tensor(tn(LLM_TENSOR_DEC_OUTPUT_NORM, "weight"), {n_embd}, 0);
+
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm_enc  = create_tensor(tn(LLM_TENSOR_ENC_ATTN_NORM,  "weight", i), {n_embd}, 0);
+                        layer.attn_rel_b_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED);
+
+                        layer.wq_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wk_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0);
+
+                        layer.ffn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd,   n_ff}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up_enc   = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+
+                        layer.attn_norm  = create_tensor(tn(LLM_TENSOR_DEC_ATTN_NORM,  "weight", i), {n_embd}, 0);
+                        layer.attn_rel_b = create_tensor(tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_DEC_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_DEC_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_DEC_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_DEC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0);
+
+                        layer.attn_norm_cross  = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_NORM,  "weight", i), {n_embd}, 0);
+                        // this tensor seems to be unused in HF transformers implementation
+                        layer.attn_rel_b_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED);
+
+                        layer.wq_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wk_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_DEC_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_DEC_FFN_GATE, "weight", i), {n_embd,   n_ff}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_DEC_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_DEC_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_T5ENCODER:
+                {
+                    const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts;
+
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm_enc  = create_tensor(tn(LLM_TENSOR_ENC_ATTN_NORM,  "weight", i), {n_embd}, 0);
+                        layer.attn_rel_b_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED);
+
+                        layer.wq_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wk_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0);
+
+                        layer.ffn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd,   n_ff}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up_enc   = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_JAIS:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
+                    output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM,   "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM,   "bias", i),   {n_embd}, 0);
+
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, 0);
+
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, 0);
+
+                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
+
+                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
+
+                        layer.ffn_gate   = create_tensor(tn(LLM_TENSOR_FFN_GATE,   "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE,   "bias", i),   {n_ff}, 0);
+
+                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_CHATGLM:
+                {
+                    tok_embd   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,      "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
+                        if (layer.wqkv == nullptr) {
+                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                            layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                            layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                            layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
+                            layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                            layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        }
+
+                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff * 2}, 0);
+
+                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                    }
+                } break;
+            case LLM_ARCH_NEMOTRON:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
+                    output        = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        // optional bias tensors
+                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     TENSOR_NOT_REQUIRED);
+                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
+                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     TENSOR_NOT_REQUIRED);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0);
+
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+
+                        // optional MLP bias
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, TENSOR_NOT_REQUIRED);
+                    }
+                } break;
+            case LLM_ARCH_EXAONE:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM,   "weight", i), {n_embd}, 0);
+                        layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+                        layer.ffn_gate   = create_tensor(tn(LLM_TENSOR_FFN_GATE,   "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN,   "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,     "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_RWKV6:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // Block 0, LN0
+                    tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
+                    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+                    const int time_mix_extra_dim = hparams.time_mix_extra_dim;
+                    const int time_decay_extra_dim = hparams.time_decay_extra_dim;
+                    const int head_size = hparams.wkv_head_size;
+                    const int attn_hidden_size = n_embd;
+                    const int ffn_size = hparams.n_ff_arr[0];
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
+
+                        layer.attn_norm_2   = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i),   {n_embd}, 0);
+
+                        layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5}, 0);
+                        layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0);
+
+                        layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0);
+                        layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        GGML_ASSERT(!(layer.time_mix_lerp_fused == NULL && layer.time_mix_lerp_w == NULL));
+
+                        layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, 0);
+                        layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0);
+                        layer.time_mix_decay_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}, 0);
+                        layer.time_mix_decay_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size}, 0);
+                        layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0);
+                        layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0);
+                        layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0);
+                        layer.time_mix_gate = create_tensor(tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd}, 0);
+
+                        layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, 0);
+                        layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, 0);
+                        layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0);
+
+                        layer.channel_mix_lerp_k = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0);
+                        layer.channel_mix_lerp_r = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, 0);
+
+                        layer.channel_mix_key = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_KEY, "weight", i), {n_embd, ffn_size}, 0);
+                        layer.channel_mix_value = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_VALUE, "weight", i), {ffn_size, n_embd}, 0);
+                        layer.channel_mix_receptance = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "weight", i), {n_embd, n_embd}, 0);
+                    }
+
+                } break;
+            case LLM_ARCH_RWKV6QWEN2:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+                    const int time_mix_extra_dim = hparams.time_mix_extra_dim;
+                    const int time_decay_extra_dim = hparams.time_decay_extra_dim;
+                    const int head_size = hparams.wkv_head_size;
+                    const int attn_hidden_size = n_embd;
+                    const int n_head_kv = hparams.n_head_kv();
+                    int attn_key_value_size;
+                    if (n_head_kv == 0 || attn_hidden_size / head_size == n_head_kv) {
+                        attn_key_value_size = attn_hidden_size;
+                    } else {
+                        attn_key_value_size = n_head_kv * head_size;
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5}, 0);
+                        layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0);
+
+                        layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0);
+                        layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, 0);
+
+                        layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0);
+                        layer.time_mix_decay_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}, 0);
+                        layer.time_mix_decay_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size}, 0);
+                        layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {n_embd, attn_key_value_size}, 0);
+                        layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {n_embd, attn_key_value_size}, 0);
+                        layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0);
+                        layer.time_mix_gate = create_tensor(tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd}, 0);
+                        // optional bias tensors
+                        layer.time_mix_key_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "bias", i), {attn_key_value_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.time_mix_value_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "bias", i), {attn_key_value_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.time_mix_receptance_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "bias", i), {attn_hidden_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
+                        layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_CHAMELEON:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, 0);
+                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0);
+                        layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i),  {n_embd_head_k, n_head}, TENSOR_NOT_REQUIRED);
+                        layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i),  {n_embd_head_k, n_head_kv}, TENSOR_NOT_REQUIRED);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_SOLAR:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    {
+                        output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                        output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.bskcn_tv = create_tensor(tn(LLM_TENSOR_BSKCN_TV, "weight", i), {2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_WAVTOKENIZER_DEC:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd_features, n_vocab}, 0);
+
+                    conv1d   = create_tensor(tn(LLM_TENSOR_CONV1D, "weight"), {7, hparams.n_embd_features, hparams.posnet.n_embd}, 0);
+                    conv1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias"),   {1, hparams.posnet.n_embd}, 0);
+
+                    // posnet
+                    {
+                        const int64_t n_embd = hparams.posnet.n_embd;
+
+                        for (uint32_t i = 0; i < hparams.posnet.n_layer; ++i) {
+                            auto & layer = layers[i].posnet;
+
+                            // posnet:
+                            //
+                            //  - resnet
+                            //  - resnet
+                            //  - attn
+                            //  - resnet
+                            //  - resnet
+                            //  - norm
+                            //
+                            switch (i) {
+                                case 0:
+                                case 1:
+                                case 3:
+                                case 4:
+                                    {
+                                        layer.norm1   = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "weight", i), {1, n_embd}, 0);
+                                        layer.norm1_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "bias",   i), {1, n_embd}, 0);
+
+                                        layer.conv1   = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "weight", i), {3, n_embd, n_embd}, 0);
+                                        layer.conv1_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "bias",   i), {1, n_embd}, 0);
+
+                                        layer.norm2   = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "weight", i), {1, n_embd}, 0);
+                                        layer.norm2_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "bias",   i), {1, n_embd}, 0);
+
+                                        layer.conv2   = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "weight", i), {3, n_embd, n_embd}, 0);
+                                        layer.conv2_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "bias",   i), {1, n_embd}, 0);
+                                    } break;
+                                case 2:
+                                    {
+                                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", i), {1, n_embd}, 0);
+                                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias",   i), {1, n_embd}, 0);
+
+                                        layer.attn_q      = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q,    "weight", i), {1, n_embd, n_embd}, 0);
+                                        layer.attn_q_b    = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q,    "bias",   i), {1, n_embd}, 0);
+
+                                        layer.attn_k      = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K,    "weight", i), {1, n_embd, n_embd}, 0);
+                                        layer.attn_k_b    = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K,    "bias",   i), {1, n_embd}, 0);
+
+                                        layer.attn_v      = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V,    "weight", i), {1, n_embd, n_embd}, 0);
+                                        layer.attn_v_b    = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V,    "bias",   i), {1, n_embd}, 0);
+
+                                        layer.attn_o      = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT,  "weight", i), {1, n_embd, n_embd}, 0);
+                                        layer.attn_o_b    = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT,  "bias",   i), {1, n_embd}, 0);
+                                    } break;
+                                case 5:
+                                    {
+                                        layer.norm   = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", i), {1, n_embd}, 0);
+                                        layer.norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias",   i), {1, n_embd}, 0);
+                                    } break;
+                                default: GGML_ABORT("unknown posnet layer");
+                            };
+                        }
+                    }
+
+                    GGML_ASSERT(hparams.posnet.n_embd == hparams.convnext.n_embd);
+
+                    tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {hparams.posnet.n_embd}, 0);
+                    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {hparams.posnet.n_embd}, 0);
+
+                    // convnext
+                    {
+                        const int64_t n_embd = hparams.convnext.n_embd;
+
+                        for (uint32_t i = 0; i < hparams.convnext.n_layer; ++i) {
+                            auto & layer = layers[i].convnext;
+
+                            layer.dw     = create_tensor(tn(LLM_TENSOR_CONVNEXT_DW,    "weight", i), {7, 1, n_embd}, 0);
+                            layer.dw_b   = create_tensor(tn(LLM_TENSOR_CONVNEXT_DW,    "bias",   i), {1, n_embd}, 0);
+
+                            layer.norm   = create_tensor(tn(LLM_TENSOR_CONVNEXT_NORM,  "weight", i), {n_embd}, 0);
+                            layer.norm_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_NORM,  "bias",   i), {n_embd}, 0);
+
+                            layer.pw1    = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW1,   "weight", i), {n_embd, n_ff}, 0);
+                            layer.pw1_b  = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW1,   "bias",   i), {n_ff}, 0);
+
+                            layer.pw2    = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW2,   "weight", i), {n_ff, n_embd}, 0);
+                            layer.pw2_b  = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW2,   "bias",   i), {n_embd}, 0);
+
+                            layer.gamma  = create_tensor(tn(LLM_TENSOR_CONVNEXT_GAMMA, "weight", i), {n_embd}, 0);
+                        }
+
+                        // output
+                        output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                        output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
+                    }
+
+                    output   = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0);
+                    output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"),   {n_embd}, 0);
+                } break;
+            default:
+                throw std::runtime_error("unknown architecture");
         }
 
-        // maintain a list of tokens that cause end-of-generation
-        // this is currently determined based on the token text, which is obviously not ideal
-        // ref: https://github.com/ggerganov/llama.cpp/issues/9606
-        vocab.special_eog_ids.clear();
-
-        if (vocab.special_fim_pad_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_fim_pad_id) == 0) {
-            vocab.special_eog_ids.insert(vocab.special_fim_pad_id);
-        }
-
-        if (vocab.special_fim_rep_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_fim_rep_id) == 0) {
-            vocab.special_eog_ids.insert(vocab.special_fim_rep_id);
-        }
-
-        if (vocab.special_fim_sep_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_fim_sep_id) == 0) {
-            vocab.special_eog_ids.insert(vocab.special_fim_sep_id);
-        }
-
-        for (const auto & t : vocab.token_to_id) {
-            if (false
-                    || t.first == "<|eot_id|>"
-                    || t.first == "<|im_end|>"
-                    || t.first == "<|end|>"
-                    || t.first == ""
-                    || t.first == "<|endoftext|>"
-                    || t.first == "<|eom_id|>"
-                    || t.first == ""
-               ) {
-                vocab.special_eog_ids.insert(t.second);
-                if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                    LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                            __func__, t.second, t.first.c_str());
-                    vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                }
-            } else {
-                // token is control, but not marked as EOG -> print a debug log
-                if (vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL && vocab.special_eog_ids.count(t.second) == 0) {
-                    LLAMA_LOG_DEBUG("%s: control token: %6d '%s' is not marked as EOG\n",
-                            __func__, t.second, t.first.c_str());
-                }
-            }
-        }
-
-        // sanity checks
-        if (vocab.special_eos_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_eos_id) == 0) {
-            vocab.special_eog_ids.insert(vocab.special_eos_id);
-            LLAMA_LOG_WARN("%s: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
-        }
-
-        if (vocab.special_eot_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_eot_id) == 0) {
-            vocab.special_eog_ids.insert(vocab.special_eot_id);
-            LLAMA_LOG_WARN("%s: special_eot_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
-        }
-
-        if (vocab.special_eom_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_eom_id) == 0) {
-            vocab.special_eog_ids.insert(vocab.special_eom_id);
-            LLAMA_LOG_WARN("%s: special_eom_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        if (n_moved_tensors > 0) {
+            LLAMA_LOG_DEBUG("%s: tensor '%s' (%s) (and %d others) cannot be used with preferred buffer type %s, using %s instead\n",
+                __func__, first_moved_tensor->name, ggml_type_name(first_moved_tensor->type), n_moved_tensors - 1,
+                ggml_backend_buft_name(first_moved_from_buft), ggml_backend_buft_name(first_moved_to_buft));
         }
     }
 
-    // build special tokens cache
-    {
-        for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
-            if (vocab.id_to_token[id].attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)) {
-                vocab.cache_special_tokens.push_back(id);
-            }
+    ml.done_getting_tensors();
+
+    ml.init_mappings(true, use_mlock ? &pimpl->mlock_mmaps : nullptr);
+    pimpl->mappings.reserve(ml.mappings.size());
+
+    // create the backend buffers
+    std::vector> ctx_bufs;
+    ctx_bufs.reserve(ctx_map.size());
+
+    // Ensure we have enough capacity for the maximum backend buffer we will potentially create
+    const size_t n_max_backend_buffer = ctx_map.size() * ml.files.size();
+    pimpl->bufs.reserve(n_max_backend_buffer);
+
+    for (auto & it : ctx_map) {
+        ggml_backend_buffer_type_t buft = it.first;
+        ggml_context * ctx              = it.second;
+
+        // skip contexts without tensors
+        if (ggml_get_first_tensor(ctx) == nullptr) {
+            continue;
         }
 
-        std::sort(vocab.cache_special_tokens.begin(), vocab.cache_special_tokens.end(),
-            [&] (const llama_vocab::id a, const llama_vocab::id b) {
-                return vocab.id_to_token[a].text.size() > vocab.id_to_token[b].text.size();
-            }
-        );
+        llama_buf_map buf_map;
+        buf_map.reserve(n_max_backend_buffer);
 
-        LLAMA_LOG_INFO("%s: special tokens cache size = %u\n", __func__, (uint32_t)vocab.cache_special_tokens.size());
-    }
-
-    // build token to piece cache
-    {
-        size_t size_cache = 0;
-
-        std::vector cache_token_to_piece(n_vocab);
-
-        for (uint32_t id = 0; id < n_vocab; ++id) {
-            cache_token_to_piece[id] = llama_token_to_piece(&model, id, true);
-
-            size_cache += cache_token_to_piece[id].size();
+        // check if it is possible to use buffer_from_host_ptr with this buffer type
+        ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft);
+        if (!dev) {
+            // FIXME: workaround for CPU backend buft having a NULL device
+            dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
         }
+        ggml_backend_dev_props props;
+        ggml_backend_dev_get_props(dev, &props);
+        bool buffer_from_host_ptr_supported = props.caps.buffer_from_host_ptr;
+        bool is_default_buft = buft == ggml_backend_dev_buffer_type(dev);
 
-        std::swap(vocab.cache_token_to_piece, cache_token_to_piece);
-
-        LLAMA_LOG_INFO("%s: token to piece cache size = %.4f MB\n", __func__, size_cache / 1024.0 / 1024.0);
-    }
-
-    // Handle per token attributes
-    //NOTE: Each model customizes per token attributes.
-    //NOTE: Per token attributes are missing from the GGUF file.
-    //TODO: Extract attributes from GGUF file.
-    {
-        auto _contains_any = [] (const std::string &str, const std::vector &substrs) -> bool {
-            for (auto substr : substrs) {
-                if (str.find(substr) < std::string::npos) {
-                    return true;
+        if (ml.use_mmap && use_mmap_buffer && buffer_from_host_ptr_supported && is_default_buft) {
+            for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
+                // only the mmap region containing the tensors in the model is mapped to the backend buffer
+                // this is important for metal with apple silicon: if the entire model could be mapped to a metal buffer, then we could just use metal for all layers
+                // this allows using partial offloading when the model size exceeds the metal buffer size, but not the RAM size
+                void * addr = nullptr;
+                size_t first, last; // NOLINT
+                ml.get_mapping_range(&first, &last, &addr, idx, ctx);
+                if (first >= last) {
+                    continue;
                 }
+                const size_t max_size = ggml_get_max_tensor_size(ctx);
+                ggml_backend_buffer_t buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size);
+                if (buf == nullptr) {
+                    throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
+                }
+                pimpl->bufs.emplace_back(buf);
+                buf_map.emplace(idx, buf);
             }
+        }
+        else {
+            ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+            if (buf == nullptr) {
+                throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
+            }
+            pimpl->bufs.emplace_back(buf);
+            if (use_mlock && ggml_backend_buffer_is_host(buf)) {
+                pimpl->mlock_bufs.emplace_back(new llama_mlock);
+                auto & mlock_buf = pimpl->mlock_bufs.back();
+                mlock_buf->init   (ggml_backend_buffer_get_base(buf));
+                mlock_buf->grow_to(ggml_backend_buffer_get_size(buf));
+            }
+            for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
+                buf_map.emplace(idx, buf);
+            }
+        }
+
+        if (pimpl->bufs.empty()) {
+            throw std::runtime_error("failed to allocate buffer");
+        }
+
+        for (auto & buf : buf_map) {
+            // indicate that this buffer contains weights
+            // this is used by ggml_backend_sched to improve op scheduling: ops that use a weight are preferably scheduled to the backend that contains the weight
+            ggml_backend_buffer_set_usage(buf.second, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+        }
+
+        ctx_bufs.emplace_back(ctx, buf_map);
+    }
+
+    if (llama_supports_gpu_offload()) {
+        const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
+
+        LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_gpu);
+        if (n_gpu_layers > (int) hparams.n_layer) {
+            LLAMA_LOG_INFO("%s: offloading output layer to GPU\n", __func__);
+        }
+
+        const int max_backend_supported_layers = hparams.n_layer + 1;
+        const int max_offloadable_layers       = hparams.n_layer + 1;
+
+        LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers);
+    }
+
+    // print memory requirements per buffer type
+    for (auto & buf : pimpl->bufs) {
+        LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0);
+    }
+
+    // populate tensors_by_name
+    for (auto & ctx : pimpl->ctxs) {
+        for (auto * cur = ggml_get_first_tensor(ctx.get()); cur != NULL; cur = ggml_get_next_tensor(ctx.get(), cur)) {
+            tensors_by_name.emplace_back(ggml_get_name(cur), cur);
+        }
+    }
+
+    // load tensor data
+    for (auto & it : ctx_bufs) {
+        ggml_context * ctx = it.first;
+        auto & bufs = it.second;
+        if (!ml.load_all_data(ctx, bufs, use_mlock ? &pimpl->mlock_mmaps : NULL, params.progress_callback, params.progress_callback_user_data)) {
             return false;
-        };
-
-        auto _set_tokenid_attr = [&] (const llama_vocab::id id, llama_token_attr attr, bool value) {
-            uint32_t current = vocab.id_to_token.at(id).attr;
-            current = value ? (current | attr) : (current & ~attr);
-            vocab.id_to_token[id].attr = (llama_token_attr) current;
-        };
-
-        auto _set_token_attr = [&] (const std::string & token, llama_token_attr attr, bool value) {
-            _set_tokenid_attr(vocab.token_to_id.at(token), attr, value);
-        };
-
-        std::string model_name;
-        std::string tokenizer_pre;
-
-        ml.get_key(LLM_KV_GENERAL_NAME, model_name, false);
-        ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
-
-        // model name to lowercase
-        std::transform(model_name.begin(), model_name.end(), model_name.begin(),
-            [] (const std::string::value_type x) {
-                return std::tolower(x);
-            }
-        );
-
-        // set attributes by model/tokenizer name
-        if (_contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})) {
-            _set_token_attr("", LLAMA_TOKEN_ATTR_LSTRIP, true);
-        } else if (_contains_any(model_name, {"phi-3", "phi3"})) {
-            for (auto id : vocab.cache_special_tokens) {
-                _set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true);
-            }
-            for (auto token : {""}) {
-                _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, true);
-            }
-            for (auto token : {"", "", "<|endoftext|>"}) {
-                _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false);
-            }
         }
     }
+
+    if (use_mmap_buffer) {
+        for (auto & mapping : ml.mappings) {
+            pimpl->mappings.emplace_back(std::move(mapping));
+        }
+    }
+
+    return true;
 }
 
-void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
-    const auto & hparams = model.hparams;
-    const auto & vocab   = model.vocab;
+std::string llama_model::arch_name() const {
+    return llm_arch_name(arch);
+}
 
+std::string llama_model::type_name() const {
+    return llm_type_name(type);
+}
+
+std::string llama_model::desc() const {
+    return pimpl->desc_str;
+}
+
+size_t llama_model::size() const {
+    return pimpl->n_bytes;
+}
+
+size_t llama_model::max_nodes() const {
+    return std::max(8192, tensors_by_name.size()*5);
+}
+
+size_t llama_model::n_devices() const {
+    return devices.size();
+}
+
+uint64_t llama_model::n_elements() const {
+    return pimpl->n_elements;
+}
+
+void llama_model::print_info() const {
     const char * rope_scaling_type = LLAMA_ROPE_SCALING_TYPES.at(hparams.rope_scaling_type_train);
 
     auto print_f = [](const std::function & f, uint32_t n) {
@@ -1838,11 +3729,7 @@ void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     };
 
     // hparams
-    LLAMA_LOG_INFO("%s: format           = %s\n",     __func__, llama_file_version_name(ml.fver));
-    LLAMA_LOG_INFO("%s: arch             = %s\n",     __func__, llm_arch_name(model.arch));
-    LLAMA_LOG_INFO("%s: vocab type       = %s\n",     __func__, llama_model_vocab_type_name(vocab.type));
-    LLAMA_LOG_INFO("%s: n_vocab          = %u\n",     __func__, hparams.n_vocab);
-    LLAMA_LOG_INFO("%s: n_merges         = %u\n",     __func__, (int) vocab.bpe_ranks.size());
+    LLAMA_LOG_INFO("%s: arch             = %s\n",     __func__, arch_name().c_str());
     LLAMA_LOG_INFO("%s: vocab_only       = %d\n",     __func__, hparams.vocab_only);
 
     if (!hparams.vocab_only) {
@@ -1881,60 +3768,28 @@ void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
         LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms   = %d\n",     __func__, hparams.ssm_dt_b_c_rms);
     }
 
-    LLAMA_LOG_INFO("%s: model type       = %s\n",     __func__, llama_model_type_name(model).c_str());
-    LLAMA_LOG_INFO("%s: model ftype      = %s\n",     __func__, llama_model_ftype_name(model).c_str());
-    if (ml.n_elements >= 1e12) {
-        LLAMA_LOG_INFO("%s: model params     = %.2f T\n", __func__, ml.n_elements*1e-12);
-    } else if (ml.n_elements >= 1e9) {
-        LLAMA_LOG_INFO("%s: model params     = %.2f B\n", __func__, ml.n_elements*1e-9);
-    } else if (ml.n_elements >= 1e6) {
-        LLAMA_LOG_INFO("%s: model params     = %.2f M\n", __func__, ml.n_elements*1e-6);
+    LLAMA_LOG_INFO("%s: model type       = %s\n",     __func__, type_name().c_str());
+    if (pimpl->n_elements >= 1e12) {
+        LLAMA_LOG_INFO("%s: model params     = %.2f T\n", __func__, pimpl->n_elements*1e-12);
+    } else if (pimpl->n_elements >= 1e9) {
+        LLAMA_LOG_INFO("%s: model params     = %.2f B\n", __func__, pimpl->n_elements*1e-9);
+    } else if (pimpl->n_elements >= 1e6) {
+        LLAMA_LOG_INFO("%s: model params     = %.2f M\n", __func__, pimpl->n_elements*1e-6);
     } else {
-        LLAMA_LOG_INFO("%s: model params     = %.2f K\n", __func__, ml.n_elements*1e-3);
-    }
-    if (ml.n_bytes < GiB) {
-        LLAMA_LOG_INFO("%s: model size       = %.2f MiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0,        ml.n_bytes*8.0/ml.n_elements);
-    } else {
-        LLAMA_LOG_INFO("%s: model size       = %.2f GiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements);
+        LLAMA_LOG_INFO("%s: model params     = %.2f K\n", __func__, pimpl->n_elements*1e-3);
     }
 
     // general kv
-    LLAMA_LOG_INFO("%s: general.name     = %s\n",    __func__, model.name.c_str());
+    LLAMA_LOG_INFO("%s: general.name     = %s\n",    __func__, name.c_str());
 
-    // special tokens
-    if (vocab.special_bos_id  != -1)    { LLAMA_LOG_INFO( "%s: BOS token        = %d '%s'\n", __func__, vocab.special_bos_id,     vocab.id_to_token[vocab.special_bos_id].text.c_str() );  }
-    if (vocab.special_eos_id  != -1)    { LLAMA_LOG_INFO( "%s: EOS token        = %d '%s'\n", __func__, vocab.special_eos_id,     vocab.id_to_token[vocab.special_eos_id].text.c_str() );  }
-    if (vocab.special_eot_id  != -1)    { LLAMA_LOG_INFO( "%s: EOT token        = %d '%s'\n", __func__, vocab.special_eot_id,     vocab.id_to_token[vocab.special_eot_id].text.c_str() );  }
-    if (vocab.special_eom_id  != -1)    { LLAMA_LOG_INFO( "%s: EOM token        = %d '%s'\n", __func__, vocab.special_eom_id,     vocab.id_to_token[vocab.special_eom_id].text.c_str() );  }
-    if (vocab.special_unk_id  != -1)    { LLAMA_LOG_INFO( "%s: UNK token        = %d '%s'\n", __func__, vocab.special_unk_id,     vocab.id_to_token[vocab.special_unk_id].text.c_str() );  }
-    if (vocab.special_sep_id  != -1)    { LLAMA_LOG_INFO( "%s: SEP token        = %d '%s'\n", __func__, vocab.special_sep_id,     vocab.id_to_token[vocab.special_sep_id].text.c_str() );  }
-    if (vocab.special_pad_id  != -1)    { LLAMA_LOG_INFO( "%s: PAD token        = %d '%s'\n", __func__, vocab.special_pad_id,     vocab.id_to_token[vocab.special_pad_id].text.c_str() );  }
-    if (vocab.special_cls_id  != -1)    { LLAMA_LOG_INFO( "%s: CLS token        = %d '%s'\n", __func__, vocab.special_cls_id,     vocab.id_to_token[vocab.special_cls_id].text.c_str() );  }
-    if (vocab.special_mask_id != -1)    { LLAMA_LOG_INFO( "%s: MASK token       = %d '%s'\n", __func__, vocab.special_mask_id,    vocab.id_to_token[vocab.special_mask_id].text.c_str() ); }
-
-    if (vocab.linefeed_id != -1)        { LLAMA_LOG_INFO( "%s: LF token         = %d '%s'\n", __func__, vocab.linefeed_id,        vocab.id_to_token[vocab.linefeed_id].text.c_str() ); }
-
-    if (vocab.special_fim_pre_id != -1) { LLAMA_LOG_INFO( "%s: FIM PRE token    = %d '%s'\n", __func__, vocab.special_fim_pre_id, vocab.id_to_token[vocab.special_fim_pre_id].text.c_str() ); }
-    if (vocab.special_fim_suf_id != -1) { LLAMA_LOG_INFO( "%s: FIM SUF token    = %d '%s'\n", __func__, vocab.special_fim_suf_id, vocab.id_to_token[vocab.special_fim_suf_id].text.c_str() ); }
-    if (vocab.special_fim_mid_id != -1) { LLAMA_LOG_INFO( "%s: FIM MID token    = %d '%s'\n", __func__, vocab.special_fim_mid_id, vocab.id_to_token[vocab.special_fim_mid_id].text.c_str() ); }
-    if (vocab.special_fim_pad_id != -1) { LLAMA_LOG_INFO( "%s: FIM PAD token    = %d '%s'\n", __func__, vocab.special_fim_pad_id, vocab.id_to_token[vocab.special_fim_pad_id].text.c_str() ); }
-    if (vocab.special_fim_rep_id != -1) { LLAMA_LOG_INFO( "%s: FIM REP token    = %d '%s'\n", __func__, vocab.special_fim_rep_id, vocab.id_to_token[vocab.special_fim_rep_id].text.c_str() ); }
-    if (vocab.special_fim_sep_id != -1) { LLAMA_LOG_INFO( "%s: FIM SEP token    = %d '%s'\n", __func__, vocab.special_fim_sep_id, vocab.id_to_token[vocab.special_fim_sep_id].text.c_str() ); }
-
-    for (const auto & id : vocab.special_eog_ids) {
-        LLAMA_LOG_INFO( "%s: EOG token        = %d '%s'\n", __func__, id, vocab.id_to_token[id].text.c_str() );
-    }
-
-    LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, vocab.max_token_len);
-
-    if (model.arch == LLM_ARCH_DEEPSEEK) {
+    if (arch == LLM_ARCH_DEEPSEEK) {
         LLAMA_LOG_INFO("%s: n_layer_dense_lead   = %d\n",     __func__, hparams.n_layer_dense_lead);
         LLAMA_LOG_INFO("%s: n_ff_exp             = %d\n",     __func__, hparams.n_ff_exp);
         LLAMA_LOG_INFO("%s: n_expert_shared      = %d\n",     __func__, hparams.n_expert_shared);
         LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n",   __func__, hparams.expert_weights_scale);
     }
 
-    if (model.arch == LLM_ARCH_DEEPSEEK2) {
+    if (arch == LLM_ARCH_DEEPSEEK2) {
         LLAMA_LOG_INFO("%s: n_layer_dense_lead   = %d\n",     __func__, hparams.n_layer_dense_lead);
         LLAMA_LOG_INFO("%s: n_lora_q             = %d\n",     __func__, hparams.n_lora_q);
         LLAMA_LOG_INFO("%s: n_lora_kv            = %d\n",     __func__, hparams.n_lora_kv);
@@ -1946,16 +3801,88 @@ void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
         LLAMA_LOG_INFO("%s: rope_yarn_log_mul    = %.4f\n",   __func__, hparams.rope_yarn_log_mul);
     }
 
-    if (model.arch == LLM_ARCH_QWEN2MOE) {
+    if (arch == LLM_ARCH_QWEN2MOE) {
         LLAMA_LOG_INFO("%s: n_ff_exp         = %d\n",     __func__, hparams.n_ff_exp);
         LLAMA_LOG_INFO("%s: n_ff_shexp       = %d\n",     __func__, hparams.n_ff_shexp);
     }
 
-    if (model.arch == LLM_ARCH_MINICPM || model.arch == LLM_ARCH_GRANITE || model.arch == LLM_ARCH_GRANITE_MOE) {
+    if (arch == LLM_ARCH_MINICPM || arch == LLM_ARCH_GRANITE || arch == LLM_ARCH_GRANITE_MOE) {
         LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale);
         LLAMA_LOG_INFO("%s: f_residual_scale  = %f\n", __func__, hparams.f_residual_scale);
         LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
     }
+
+    vocab.print_info();
+}
+
+ggml_backend_dev_t llama_model::dev_layer(int il) const {
+    return pimpl->dev_layer.at(il).dev;
+}
+
+ggml_backend_dev_t llama_model::dev_output() const {
+    return pimpl->dev_output.dev;
+}
+
+template
+static bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, F & fn) {
+    ggml_init_params params = {
+        /*.mem_size   =*/ ggml_tensor_overhead()*8,
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true,
+    };
+
+    ggml_context_ptr ctx { ggml_init(params) };
+    if (!ctx) {
+        throw std::runtime_error(format("failed to create ggml context"));
+    }
+
+    ggml_backend_buffer_ptr buf { ggml_backend_buft_alloc_buffer(buft, 0) };
+    ggml_tensor * op_tensor = fn(ctx.get());
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        if (op_tensor->src[i] != nullptr) {
+            assert(op_tensor->src[i]->buffer == nullptr);
+            op_tensor->src[i]->buffer = buf.get();
+        }
+    }
+
+    bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
+
+    return op_supported;
+}
+
+template
+static ggml_backend_buffer_type_t select_buft(const buft_list_t & buft_list, const F & fn) {
+    for (const auto & cur : buft_list) {
+        ggml_backend_dev_t cur_dev = cur.first;
+        ggml_backend_buffer_type_t cur_buft = cur.second;
+        if (buft_supported(cur_buft, cur_dev, fn)) {
+            return cur_buft;
+        }
+    }
+
+    throw std::runtime_error(format("no suitable buffer type found"));
+}
+
+ggml_backend_buffer_type_t llama_model::select_buft(int il) const {
+    return ::select_buft(
+            *pimpl->dev_layer.at(il).buft_list,
+            [&](ggml_context * ctx) {
+                ggml_tensor * cur = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd);
+                ggml_tensor * layer_dir = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd);
+                return ggml_add(ctx, cur, layer_dir);
+            });
+}
+
+const struct ggml_tensor * llama_model::get_tensor(const char * name) const {
+    auto it = std::find_if(tensors_by_name.begin(), tensors_by_name.end(),
+            [name](const std::pair & it) {
+                return it.first == name;
+            });
+    if (it == tensors_by_name.end()) {
+        return nullptr;
+    }
+
+    return it->second;
 }
 
 //
@@ -1969,7 +3896,6 @@ struct llama_model_params llama_model_default_params() {
         /*.split_mode                  =*/ LLAMA_SPLIT_MODE_LAYER,
         /*.main_gpu                    =*/ 0,
         /*.tensor_split                =*/ nullptr,
-        /*.rpc_servers                 =*/ nullptr,
         /*.progress_callback           =*/ nullptr,
         /*.progress_callback_user_data =*/ nullptr,
         /*.kv_overrides                =*/ nullptr,
@@ -1987,35 +3913,59 @@ struct llama_model_params llama_model_default_params() {
     return result;
 }
 
+const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model) {
+    return &model->vocab;
+}
+
 void llama_free_model(struct llama_model * model) {
+    llama_model_free(model);
+}
+
+void llama_model_free(struct llama_model * model) {
     delete model;
 }
 
-enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
-    return model->vocab.type;
-}
-
-int32_t llama_n_vocab(const struct llama_model * model) {
-    return model->hparams.n_vocab;
-}
-
-int32_t llama_n_ctx_train(const struct llama_model * model) {
+int32_t llama_model_n_ctx_train(const struct llama_model * model) {
     return model->hparams.n_ctx_train;
 }
 
-int32_t llama_n_embd(const struct llama_model * model) {
+int32_t llama_model_n_embd(const struct llama_model * model) {
     return model->hparams.n_embd;
 }
 
-int32_t llama_n_layer(const struct llama_model * model) {
+int32_t llama_model_n_layer(const struct llama_model * model) {
     return model->hparams.n_layer;
 }
 
-int32_t llama_n_head(const struct llama_model * model) {
+int32_t llama_model_n_head(const struct llama_model * model) {
     return model->hparams.n_head();
 }
 
-enum llama_rope_type llama_rope_type(const struct llama_model * model) {
+int32_t llama_model_n_head_kv(const struct llama_model * model) {
+    return model->hparams.n_head_kv();
+}
+
+// deprecated
+int32_t llama_n_ctx_train(const struct llama_model * model) {
+    return llama_model_n_ctx_train(model);
+}
+
+// deprecated
+int32_t llama_n_embd(const struct llama_model * model) {
+    return llama_model_n_embd(model);
+}
+
+// deprecated
+int32_t llama_n_layer(const struct llama_model * model) {
+    return llama_model_n_layer(model);
+}
+
+// deprecated
+int32_t llama_n_head(const struct llama_model * model) {
+    return llama_model_n_head(model);
+}
+
+enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
     switch (model->arch) {
         // these models do not use RoPE
         case LLM_ARCH_GPT2:
@@ -2029,6 +3979,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_T5ENCODER:
         case LLM_ARCH_JAIS:
         case LLM_ARCH_RWKV6:
+        case LLM_ARCH_RWKV6QWEN2:
         case LLM_ARCH_WAVTOKENIZER_DEC:
             return LLAMA_ROPE_TYPE_NONE;
 
@@ -2071,6 +4022,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_OLMOE:
         case LLM_ARCH_PHI2:
         case LLM_ARCH_PHI3:
+        case LLM_ARCH_PHIMOE:
         case LLM_ARCH_GEMMA:
         case LLM_ARCH_GEMMA2:
         case LLM_ARCH_STARCODER2:
@@ -2093,7 +4045,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
     return LLAMA_ROPE_TYPE_NONE;
 }
 
-float llama_rope_freq_scale_train(const struct llama_model * model) {
+float llama_model_rope_freq_scale_train(const struct llama_model * model) {
     return model->hparams.rope_freq_scale_train;
 }
 
@@ -2137,18 +4089,26 @@ int32_t llama_model_meta_val_str_by_index(const struct llama_model * model, int3
 }
 
 int32_t llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size) {
-    return snprintf(buf, buf_size, "%s %s %s",
-            llama_model_arch_name (*model).c_str(),
-            llama_model_type_name (*model).c_str(),
-            llama_model_ftype_name(*model).c_str());
+    return snprintf(buf, buf_size, "%s", model->desc().c_str());
 }
 
 uint64_t llama_model_size(const struct llama_model * model) {
-    return model->n_bytes;
+    return model->size();
+}
+
+const char * llama_model_chat_template(const struct llama_model * model, const char * name) {
+    const auto key = name ? LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N)
+        : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE);
+    const auto & it = model->gguf_kv.find(key);
+    if (it == model->gguf_kv.end()) {
+        return nullptr;
+    }
+
+    return it->second.c_str();
 }
 
 uint64_t llama_model_n_params(const struct llama_model * model) {
-    return model->n_elements;
+    return model->n_elements();
 }
 
 bool llama_model_has_encoder(const struct llama_model * model) {
@@ -2174,6 +4134,7 @@ bool llama_model_is_recurrent(const struct llama_model * model) {
     switch (model->arch) {
         case LLM_ARCH_MAMBA:  return true;
         case LLM_ARCH_RWKV6:  return true;
+        case LLM_ARCH_RWKV6QWEN2: return true;
         default:              return false;
     }
 }
diff --git a/llama/llama.cpp/src/llama-model.h b/llama/llama.cpp/src/llama-model.h
index 5b23e2ba..7cf57587 100644
--- a/llama/llama.cpp/src/llama-model.h
+++ b/llama/llama.cpp/src/llama-model.h
@@ -4,81 +4,83 @@
 #include "llama-arch.h"
 #include "llama-hparams.h"
 #include "llama-vocab.h"
-#include "llama-mmap.h"
-
-#include "ggml-cpp.h"
 
+#include 
+#include 
+#include 
 #include 
 #include 
 
+struct llama_model_loader;
+
 // available models
-// TODO: this enum does not follow the enum naming convention
 enum llm_type {
-    MODEL_UNKNOWN,
-    MODEL_14M,
-    MODEL_17M,
-    MODEL_22M,
-    MODEL_33M,
-    MODEL_60M,
-    MODEL_70M,
-    MODEL_80M,
-    MODEL_109M,
-    MODEL_137M,
-    MODEL_160M,
-    MODEL_220M,
-    MODEL_250M,
-    MODEL_270M,
-    MODEL_335M,
-    MODEL_410M,
-    MODEL_450M,
-    MODEL_770M,
-    MODEL_780M,
-    MODEL_0_5B,
-    MODEL_1B,
-    MODEL_1_3B,
-    MODEL_1_4B,
-    MODEL_1_5B,
-    MODEL_1_6B,
-    MODEL_2B,
-    MODEL_2_8B,
-    MODEL_3B,
-    MODEL_4B,
-    MODEL_6B,
-    MODEL_6_9B,
-    MODEL_7B,
-    MODEL_8B,
-    MODEL_9B,
-    MODEL_11B,
-    MODEL_12B,
-    MODEL_13B,
-    MODEL_14B,
-    MODEL_15B,
-    MODEL_16B,
-    MODEL_20B,
-    MODEL_22B,
-    MODEL_30B,
-    MODEL_32B,
-    MODEL_34B,
-    MODEL_35B,
-    MODEL_40B,
-    MODEL_65B,
-    MODEL_70B,
-    MODEL_90B,
-    MODEL_236B,
-    MODEL_314B,
-    MODEL_671B,
-    MODEL_SMALL,
-    MODEL_MEDIUM,
-    MODEL_LARGE,
-    MODEL_XL,
-    MODEL_A1_7B,
-    MODEL_A2_7B,
-    MODEL_8x7B,
-    MODEL_8x22B,
-    MODEL_16x12B,
-    MODEL_10B_128x3_66B,
-    MODEL_57B_A14B,
-    MODEL_27B,
+    LLM_TYPE_UNKNOWN,
+    LLM_TYPE_14M,
+    LLM_TYPE_17M,
+    LLM_TYPE_22M,
+    LLM_TYPE_33M,
+    LLM_TYPE_60M,
+    LLM_TYPE_70M,
+    LLM_TYPE_80M,
+    LLM_TYPE_109M,
+    LLM_TYPE_137M,
+    LLM_TYPE_160M,
+    LLM_TYPE_220M,
+    LLM_TYPE_250M,
+    LLM_TYPE_270M,
+    LLM_TYPE_335M,
+    LLM_TYPE_410M,
+    LLM_TYPE_450M,
+    LLM_TYPE_770M,
+    LLM_TYPE_780M,
+    LLM_TYPE_0_5B,
+    LLM_TYPE_1B,
+    LLM_TYPE_1_3B,
+    LLM_TYPE_1_4B,
+    LLM_TYPE_1_5B,
+    LLM_TYPE_1_6B,
+    LLM_TYPE_2B,
+    LLM_TYPE_2_8B,
+    LLM_TYPE_3B,
+    LLM_TYPE_4B,
+    LLM_TYPE_6B,
+    LLM_TYPE_6_9B,
+    LLM_TYPE_7B,
+    LLM_TYPE_8B,
+    LLM_TYPE_9B,
+    LLM_TYPE_11B,
+    LLM_TYPE_12B,
+    LLM_TYPE_13B,
+    LLM_TYPE_14B,
+    LLM_TYPE_15B,
+    LLM_TYPE_16B,
+    LLM_TYPE_20B,
+    LLM_TYPE_22B,
+    LLM_TYPE_30B,
+    LLM_TYPE_32B,
+    LLM_TYPE_34B,
+    LLM_TYPE_35B,
+    LLM_TYPE_40B,
+    LLM_TYPE_65B,
+    LLM_TYPE_70B,
+    LLM_TYPE_90B,
+    LLM_TYPE_236B,
+    LLM_TYPE_314B,
+    LLM_TYPE_671B,
+    LLM_TYPE_SMALL,
+    LLM_TYPE_MEDIUM,
+    LLM_TYPE_LARGE,
+    LLM_TYPE_XL,
+    LLM_TYPE_A1_7B,
+    LLM_TYPE_A2_7B,
+    LLM_TYPE_8x7B,
+    LLM_TYPE_8x22B,
+    LLM_TYPE_16x12B,
+    LLM_TYPE_16x3_8B,
+    LLM_TYPE_10B_128x3_66B,
+    LLM_TYPE_57B_A14B,
+    LLM_TYPE_27B,
 };
 
 struct llama_layer_posnet {
@@ -243,15 +245,19 @@ struct llama_layer {
     struct ggml_tensor * time_mix_lerp_v     = nullptr;
     struct ggml_tensor * time_mix_lerp_r     = nullptr;
     struct ggml_tensor * time_mix_lerp_g     = nullptr;
+    struct ggml_tensor * time_mix_lerp_fused = nullptr;
 
-    struct ggml_tensor * time_mix_first      = nullptr;
-    struct ggml_tensor * time_mix_decay      = nullptr;
-    struct ggml_tensor * time_mix_decay_w1   = nullptr;
-    struct ggml_tensor * time_mix_decay_w2   = nullptr;
-    struct ggml_tensor * time_mix_key        = nullptr;
-    struct ggml_tensor * time_mix_value      = nullptr;
-    struct ggml_tensor * time_mix_receptance = nullptr;
-    struct ggml_tensor * time_mix_gate       = nullptr;
+    struct ggml_tensor * time_mix_first        = nullptr;
+    struct ggml_tensor * time_mix_decay        = nullptr;
+    struct ggml_tensor * time_mix_decay_w1     = nullptr;
+    struct ggml_tensor * time_mix_decay_w2     = nullptr;
+    struct ggml_tensor * time_mix_key          = nullptr;
+    struct ggml_tensor * time_mix_key_b        = nullptr;
+    struct ggml_tensor * time_mix_value        = nullptr;
+    struct ggml_tensor * time_mix_value_b      = nullptr;
+    struct ggml_tensor * time_mix_receptance   = nullptr;
+    struct ggml_tensor * time_mix_receptance_b = nullptr;
+    struct ggml_tensor * time_mix_gate         = nullptr;
 
     struct ggml_tensor * time_mix_ln     = nullptr;
     struct ggml_tensor * time_mix_ln_b   = nullptr;
@@ -280,7 +286,7 @@ struct llama_layer {
 
     struct ggml_tensor * bskcn_tv = nullptr;
 
-     // cross attention
+    // cross attention
     struct ggml_tensor * cross_attn_k_norm = nullptr;
     struct ggml_tensor * cross_attn_k_proj = nullptr;
     struct ggml_tensor * cross_attn_o_proj = nullptr;
@@ -296,11 +302,9 @@ struct llama_layer {
 };
 
 struct llama_model {
-    llm_type type = MODEL_UNKNOWN;
+    llm_type type = LLM_TYPE_UNKNOWN;
     llm_arch arch = LLM_ARCH_UNKNOWN;
 
-    llama_ftype ftype = LLAMA_FTYPE_ALL_F32;
-
     std::string name = "n/a";
 
     llama_hparams hparams = {};
@@ -329,117 +333,53 @@ struct llama_model {
 
     std::vector layers;
 
+    llama_model_params params;
+
     // gguf metadata
     std::unordered_map gguf_kv;
 
-    llama_split_mode split_mode;
-    int main_gpu;
-    int n_gpu_layers;
-
-    std::vector rpc_servers;
-
     // list of devices used in this model
     std::vector devices;
 
-
-    // lists of buffer types used for each layer
-    using buft_list_t = std::vector>;
-    buft_list_t cpu_buft_list;
-    std::map gpu_buft_list;
-
-    struct layer_dev {
-        ggml_backend_dev_t dev;
-        buft_list_t * buft_list;
-    };
-
-    layer_dev dev_input = {};
-    layer_dev dev_output = {};
-    std::vector dev_layer;
-
-    // contexts where the model tensors metadata is stored
-    std::vector ctxs;
-
-    // the model memory buffers for the tensor data
-    std::vector bufs;
-
-    // model memory mapped files
-    llama_mmaps mappings;
-
-    // objects representing data potentially being locked in memory
-    llama_mlocks mlock_bufs;
-    llama_mlocks mlock_mmaps;
-
     // for quantize-stats only
     std::vector> tensors_by_name;
 
     int64_t t_load_us  = 0;
     int64_t t_start_us = 0;
 
-    // total number of parameters in the model
-    uint64_t n_elements = 0;
+    explicit llama_model(const struct llama_model_params & params);
+    ~llama_model();
 
-    // total size of all the tensors in the model in bytes
-    size_t  n_bytes     = 0;
+    void load_stats  (llama_model_loader & ml);
+    void load_arch   (llama_model_loader & ml);
+    void load_hparams(llama_model_loader & ml);
+    void load_vocab  (llama_model_loader & ml);
+    bool load_tensors(llama_model_loader & ml); // returns false if cancelled by progress_callback
+
+    std::string arch_name() const;
+    std::string type_name() const;
+
+    std::string desc() const;
+
+    size_t size() const;
+    size_t max_nodes() const;
+    size_t n_devices() const;
+
+    // total number of parameters in the model
+    uint64_t n_elements() const;
+
+    void print_info() const;
+
+    ggml_backend_dev_t dev_layer(int il) const;
+    ggml_backend_dev_t dev_output() const;
+
+    ggml_backend_buffer_type_t select_buft(int il) const;
+
+    const struct ggml_tensor * get_tensor(const char * name) const;
+
+private:
+    struct impl;
+    std::unique_ptr pimpl;
 };
 
 const char * llm_type_name(llm_type type);
-
-std::string llama_model_arch_name (const llama_model & model);
-std::string llama_model_type_name (const llama_model & model);
-std::string llama_model_ftype_name(const llama_model & model);
-
-template
-bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, F & fn) {
-    ggml_init_params params = {
-        /*.mem_size   =*/ ggml_tensor_overhead()*8,
-        /*.mem_buffer =*/ NULL,
-        /*.no_alloc   =*/ true,
-    };
-
-    ggml_context_ptr ctx { ggml_init(params) };
-    if (!ctx) {
-        throw std::runtime_error("failed to create ggml context");
-    }
-
-    ggml_backend_buffer_ptr buf { ggml_backend_buft_alloc_buffer(buft, 0) };
-    ggml_tensor * op_tensor = fn(ctx.get());
-    for (int i = 0; i < GGML_MAX_SRC; i++) {
-        if (op_tensor->src[i] != nullptr) {
-            op_tensor->src[i]->buffer = buf.get();
-        }
-    }
-
-    bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
-
-    return op_supported;
-}
-
-template
-ggml_backend_buffer_type_t select_buft(const llama_model::buft_list_t & buft_list, const F & fn) {
-    for (const auto & cur : buft_list) {
-        ggml_backend_dev_t cur_dev = cur.first;
-        ggml_backend_buffer_type_t cur_buft = cur.second;
-        if (buft_supported(cur_buft, cur_dev, fn)) {
-            return cur_buft;
-        }
-    }
-
-    throw std::runtime_error("no suitable buffer type found");
-}
-
-// used by llama_adapter_cvec
-ggml_backend_buffer_type_t llama_model_select_buft(const llama_model & model, int il);
-
-// used by llama_adapter_lora
-struct ggml_tensor * llama_model_get_tensor(const struct llama_model & model, const char * name);
-
-size_t llama_model_max_nodes(const llama_model & model);
-
-struct llama_model_loader;
-
-// TODO: become llama_model methods
-void llm_load_stats     (llama_model_loader & ml, llama_model & model);
-void llm_load_arch      (llama_model_loader & ml, llama_model & model);
-void llm_load_hparams   (llama_model_loader & ml, llama_model & model);
-void llm_load_vocab     (llama_model_loader & ml, llama_model & model);
-void llm_load_print_meta(llama_model_loader & ml, llama_model & model);
diff --git a/llama/llama.cpp/src/llama-quant.cpp b/llama/llama.cpp/src/llama-quant.cpp
index 27def6fd..6eb1da08 100644
--- a/llama/llama.cpp/src/llama-quant.cpp
+++ b/llama/llama.cpp/src/llama-quant.cpp
@@ -7,14 +7,12 @@
 #include 
 #include 
 #include 
+#include 
 #include 
 #include 
 #include 
 #include 
 
-// TODO: replace with ggml API call
-#define QK_K 256
-
 static void zeros(std::ofstream & file, size_t n) {
     char zero = 0;
     for (size_t i = 0; i < n; ++i) {
@@ -22,7 +20,7 @@ static void zeros(std::ofstream & file, size_t n) {
     }
 }
 
-struct quantize_state_internal {
+struct quantize_state_impl {
     const llama_model                 & model;
     const llama_model_quantize_params * params;
 
@@ -43,13 +41,13 @@ struct quantize_state_internal {
     // used to figure out if a model shares tok_embd with the output weight
     bool has_output = false;
 
-    quantize_state_internal(const llama_model & model, const llama_model_quantize_params * params)
+    quantize_state_impl(const llama_model & model, const llama_model_quantize_params * params)
         : model(model)
         , params(params)
         {}
 };
 
-static void llama_tensor_dequantize_internal(
+static void llama_tensor_dequantize_impl(
     struct ggml_tensor * tensor, std::vector> & output, std::vector & workers,
     const size_t nelements, const int nthread
 ) {
@@ -121,7 +119,7 @@ static void llama_tensor_dequantize_internal(
     workers.clear();
 }
 
-static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
+static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
     const std::string name = ggml_get_name(tensor);
 
     // TODO: avoid hardcoded tensor names - use the TN_* constants
@@ -154,8 +152,10 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
         if (qs.params->output_tensor_type < GGML_TYPE_COUNT) {
             new_type = qs.params->output_tensor_type;
         } else {
-            int nx = tensor->ne[0];
-            if (arch == LLM_ARCH_FALCON || nx % QK_K != 0) {
+            const int64_t nx = tensor->ne[0];
+            const int64_t qk_k = ggml_blck_size(new_type);
+
+            if (arch == LLM_ARCH_FALCON || nx % qk_k != 0) {
                 new_type = GGML_TYPE_Q8_0;
             }
             else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
@@ -235,7 +235,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
         else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
                 use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) new_type = GGML_TYPE_Q6_K;
         else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && qs.i_attention_wv < 4) new_type = GGML_TYPE_Q5_K;
-        if (qs.model.type == MODEL_70B) {
+        if (qs.model.type == LLM_TYPE_70B) {
             // In the 70B model we have 8 heads sharing the same attn_v weights. As a result, the attn_v.weight tensor is
             // 8x smaller compared to attn_q.weight. Hence, we can get a nice boost in quantization accuracy with
             // nearly negligible increase in model size by quantizing this tensor with more bits:
@@ -367,20 +367,19 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
     //    if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K;
     //}
     bool convert_incompatible_tensor = false;
-    if (new_type == GGML_TYPE_Q2_K    || new_type == GGML_TYPE_Q3_K    || new_type == GGML_TYPE_Q4_K   ||
-        new_type == GGML_TYPE_Q5_K    || new_type == GGML_TYPE_Q6_K    || new_type == GGML_TYPE_IQ4_XS ||
-        new_type == GGML_TYPE_IQ2_XS  || new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_S  ||
-        new_type == GGML_TYPE_IQ3_XXS || new_type == GGML_TYPE_IQ1_S   || new_type == GGML_TYPE_IQ3_S  ||
-        new_type == GGML_TYPE_IQ1_M) {
-        int nx = tensor->ne[0];
-        int ny = tensor->ne[1];
-        if (nx % QK_K != 0) {
-            LLAMA_LOG_WARN("\n\n%s : tensor cols %d x %d are not divisible by %d, required for %s", __func__, nx, ny, QK_K, ggml_type_name(new_type));
+    {
+        const int64_t nx = tensor->ne[0];
+        const int64_t ny = tensor->ne[1];
+        const int64_t qk_k = ggml_blck_size(new_type);
+
+        if (nx % qk_k != 0) {
+            LLAMA_LOG_WARN("\n\n%s : tensor cols %" PRId64 " x %" PRId64 " are not divisible by %" PRId64 ", required for %s", __func__, nx, ny, qk_k, ggml_type_name(new_type));
             convert_incompatible_tensor = true;
         } else {
             ++qs.n_k_quantized;
         }
     }
+
     if (convert_incompatible_tensor) {
         switch (new_type) {
             case GGML_TYPE_TQ1_0:
@@ -410,7 +409,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
     return new_type;
 }
 
-static size_t llama_tensor_quantize_internal(enum ggml_type new_type, const float * f32_data, void * new_data, const int64_t chunk_size, int64_t nrows, int64_t n_per_row, const float * imatrix, std::vector & workers, const int nthread) {
+static size_t llama_tensor_quantize_impl(enum ggml_type new_type, const float * f32_data, void * new_data, const int64_t chunk_size, int64_t nrows, int64_t n_per_row, const float * imatrix, std::vector & workers, const int nthread) {
     if (nthread < 2) {
         // single-thread
         size_t new_size = ggml_quantize_chunk(new_type, f32_data, new_data, 0, nrows, n_per_row, imatrix);
@@ -464,7 +463,7 @@ static size_t llama_tensor_quantize_internal(enum ggml_type new_type, const floa
     return new_size;
 }
 
-static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
+static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
     ggml_type default_type;
     llama_ftype ftype = params->ftype;
 
@@ -526,18 +525,21 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         auto v = (std::vector*)params->kv_overrides;
         kv_overrides = v->data();
     }
-    llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, kv_overrides);
+
+    std::vector splits = {};
+    llama_model_loader ml(fname_inp, splits, use_mmap, /*check_tensors*/ true, kv_overrides);
     ml.init_mappings(false); // no prefetching
 
-    llama_model model;
-    llm_load_arch   (ml, model);
-    llm_load_hparams(ml, model);
-    llm_load_stats  (ml, model);
+    llama_model model(llama_model_default_params());
 
-    struct quantize_state_internal qs(model, params);
+    model.load_arch   (ml);
+    model.load_hparams(ml);
+    model.load_stats  (ml);
+
+    struct quantize_state_impl qs(model, params);
 
     if (params->only_copy) {
-        ftype = model.ftype;
+        ftype = ml.ftype;
     }
     const std::unordered_map> * imatrix_data = nullptr;
     if (params->imatrix) {
@@ -621,7 +623,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
 
     qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
 
-    // sanity checks
+    // sanity checks for models that have attention layers
+    if (qs.n_attention_wv != 0)
     {
         const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
         // attention layers have a non-zero number of kv heads
@@ -761,6 +764,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         quantize &= name.find("time_mix_w2.weight") == std::string::npos;
         quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos;
         quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos;
+        quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos;
 
         // do not quantize relative position bias (T5)
         quantize &= name.find("attn_rel_b.weight") == std::string::npos;
@@ -839,7 +843,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
             } else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) {
                 throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type)));
             } else {
-                llama_tensor_dequantize_internal(tensor, f32_conv_buf, workers, nelements, nthread);
+                llama_tensor_dequantize_impl(tensor, f32_conv_buf, workers, nelements, nthread);
                 f32_data = (float *) f32_conv_buf.data();
             }
 
@@ -868,7 +872,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
                 void * new_data_03 = (char *)new_data + ggml_row_size(new_type, n_per_row) * i03 * nrows;
                 const float * imatrix_03 = imatrix ? imatrix + i03 * n_per_row : nullptr;
 
-                new_size += llama_tensor_quantize_internal(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use);
+                new_size += llama_tensor_quantize_impl(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use);
             }
             LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0);
         }
@@ -877,7 +881,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
 
         // update the gguf meta data as we go
         gguf_set_tensor_type(ctx_outs[cur_split].get(), name.c_str(), new_type);
-        gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data, new_size);
+        GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), name.c_str())) == new_size);
+        gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data);
 
         // write tensor data + padding
         fout.write((const char *) new_data, new_size);
@@ -921,7 +926,7 @@ uint32_t llama_model_quantize(
         const char * fname_out,
         const llama_model_quantize_params * params) {
     try {
-        llama_model_quantize_internal(fname_inp, fname_out, params);
+        llama_model_quantize_impl(fname_inp, fname_out, params);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: failed to quantize: %s\n", __func__, err.what());
         return 1;
diff --git a/llama/llama.cpp/src/llama-sampling.cpp b/llama/llama.cpp/src/llama-sampling.cpp
index 69cea2f1..f40bf2db 100644
--- a/llama/llama.cpp/src/llama-sampling.cpp
+++ b/llama/llama.cpp/src/llama-sampling.cpp
@@ -257,7 +257,7 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
             for (int i = 0; i < (int)cur_p->size; ++i) {
                 const float val = cur_p->data[i].logit;
                 int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
-                ib = std::max(0, std::min(nbuckets-1, ib));
+                ib = std::max(0, std::min(nbuckets - 1, ib));
                 bucket_idx[i] = ib;
                 ++histo[ib];
             }
@@ -280,13 +280,13 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
             for (int i = 0; i < (int)cur_p->size; ++i) {
                 int j = bucket_idx[i];
                 if (j >= ib) {
-                    *bucket_ptrs[nbuckets-1-j]++ = cur_p->data[i];
+                    *bucket_ptrs[nbuckets - 1 - j]++ = cur_p->data[i];
                 }
             }
 
             ptr = tmp_tokens.data();
             int ndone = 0;
-            for (int j = nbuckets-1; j > ib; --j) {
+            for (int j = nbuckets - 1; j > ib; --j) {
                 std::sort(ptr, ptr + histo[j], comp);
                 ptr += histo[j];
                 ndone += histo[j];
@@ -316,6 +316,13 @@ static uint32_t get_rng_seed(uint32_t seed) {
 
 // llama_sampler API
 
+struct llama_sampler * llama_sampler_init(const struct llama_sampler_i * iface, llama_sampler_context_t ctx) {
+    return new llama_sampler {
+        /* .iface = */ iface,
+        /* .ctx   = */ ctx,
+    };
+}
+
 const char * llama_sampler_name(const struct llama_sampler * smpl) {
     if (!smpl->iface) {
         return "(null)";
@@ -347,10 +354,10 @@ struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
     }
 
     if (smpl->ctx == nullptr) {
-        return new llama_sampler {
+        return llama_sampler_init(
             /* .iface = */ smpl->iface,
-            /* .ctx   = */ nullptr,
-        };
+            /* .ctx   = */ nullptr
+        );
     }
 
     GGML_ABORT("the sampler does not support cloning");
@@ -371,7 +378,10 @@ void llama_sampler_free(struct llama_sampler * smpl) {
 llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
     const auto * logits = llama_get_logits_ith(ctx, idx);
 
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
+    const int n_vocab = llama_vocab_n_tokens(vocab);
 
     // TODO: do not allocate each time
     std::vector cur;
@@ -469,15 +479,15 @@ static struct llama_sampler_i llama_sampler_chain_i = {
 };
 
 struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_chain_i,
         /* .ctx   = */ new llama_sampler_chain {
             /* .params      = */ params,
             /* .samplers    = */ {},
             /* .t_sample_us = */ 0,
             /* .n_sample    = */ 0,
-        },
-    };
+        }
+    );
 }
 
 void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
@@ -543,10 +553,10 @@ static struct llama_sampler_i llama_sampler_greedy_i = {
 };
 
 struct llama_sampler * llama_sampler_init_greedy() {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_greedy_i,
-        /* .ctx   = */ nullptr,
-    };
+        /* .ctx   = */ nullptr
+    );
 }
 
 // dist
@@ -605,14 +615,14 @@ static struct llama_sampler_i llama_sampler_dist_i = {
 
 struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
     auto seed_cur = get_rng_seed(seed);
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_dist_i,
         /* .ctx   = */ new llama_sampler_dist {
             /* .seed     = */ seed,
             /* .seed_cur = */ seed_cur,
             /* .rng      = */ std::mt19937(seed_cur),
-        },
-    };
+        }
+    );
 }
 
 // softmax
@@ -635,10 +645,10 @@ static struct llama_sampler_i llama_sampler_softmax_i = {
 };
 
 struct llama_sampler * llama_sampler_init_softmax() {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_softmax_i,
-        /* .ctx   = */ nullptr,
-    };
+        /* .ctx   = */ nullptr
+    );
 }
 
 // top-k
@@ -675,12 +685,12 @@ static struct llama_sampler_i llama_sampler_top_k_i = {
 };
 
 struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_top_k_i,
         /* .ctx   = */ new llama_sampler_top_k {
             /* .k = */ k,
-        },
-    };
+        }
+    );
 }
 
 // top-p
@@ -741,13 +751,13 @@ static struct llama_sampler_i llama_sampler_top_p_i = {
 };
 
 struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_top_p_i,
         /* .ctx   = */ new llama_sampler_top_p {
             /* .p        = */ p,
             /* .min_keep = */ min_keep,
-        },
-    };
+        }
+    );
 }
 
 // min-p
@@ -837,13 +847,13 @@ static struct llama_sampler_i llama_sampler_min_p_i = {
 };
 
 struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_min_p_i,
         /* .ctx   = */ new llama_sampler_min_p {
             /* .p        = */ p,
             /* .min_keep = */ min_keep,
-        },
-    };
+        }
+    );
 }
 
 // typical
@@ -936,13 +946,13 @@ static struct llama_sampler_i llama_sampler_typical_i = {
 };
 
 struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_typical_i,
         /* .ctx   = */ new llama_sampler_typical {
             /* .p        = */ p,
             /* .min_keep = */ min_keep,
-        },
-    };
+        }
+    );
 }
 
 // temp
@@ -980,12 +990,12 @@ static struct llama_sampler_i llama_sampler_temp_i = {
 };
 
 struct llama_sampler * llama_sampler_init_temp(float temp) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_temp_i,
         /* .ctx   = */ new llama_sampler_temp {
             /*.temp = */ temp,
-        },
-    };
+        }
+    );
 }
 
 // temp-ext
@@ -1090,14 +1100,14 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = {
 };
 
 struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_temp_ext_i,
         /* .ctx   = */ new llama_sampler_temp_ext {
             /* .temp     = */ temp,
             /* .delta    = */ delta,
             /* .exponent = */ exponent,
-        },
-    };
+        }
+    );
 }
 
 // xtc
@@ -1182,7 +1192,7 @@ static struct llama_sampler_i llama_sampler_xtc_i = {
 
 struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
     auto seed_cur = get_rng_seed(seed);
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_xtc_i,
         /* .ctx   = */ new llama_sampler_xtc {
             /* .probability   = */ p,
@@ -1191,8 +1201,8 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep,
             /* .seed          = */ seed,
             /* .seed_cur      = */ seed_cur,
             /* .rng           = */ std::mt19937(seed_cur),
-        },
-    };
+        }
+    );
 }
 
 // mirostat
@@ -1289,7 +1299,7 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {
 
 struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
     auto seed_cur = get_rng_seed(seed);
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_mirostat_i,
         /* .ctx   = */ new llama_sampler_mirostat {
             /* .n_vocab  = */ n_vocab,
@@ -1300,8 +1310,8 @@ struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t see
             /* .m        = */ m,
             /* .mu       = */ 2.0f*tau,
             /* .rng      = */ std::mt19937(seed_cur),
-        },
-    };
+        }
+    );
 }
 
 // mirostat v2
@@ -1388,7 +1398,7 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
 
 struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
     auto seed_cur = get_rng_seed(seed);
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_mirostat_v2_i,
         /* .ctx   = */ new llama_sampler_mirostat_v2 {
             /* .seed     = */ seed,
@@ -1397,8 +1407,8 @@ struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau,
             /* .eta      = */ eta,
             /* .mu       = */ 2.0f*tau,
             /* .rng      = */ std::mt19937(seed_cur),
-        },
-    };
+        }
+    );
 }
 
 // grammar
@@ -1430,13 +1440,30 @@ static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token
     }
 }
 
+// Fwd declare to break reset --> init_impl --> llama_sampler_grammar_i --> reset cycle.
+static struct llama_sampler * llama_sampler_init_grammar_impl(
+        const struct llama_vocab * vocab,
+                      const char * grammar_str,
+                      const char * grammar_root,
+                              bool lazy,
+                     const char ** trigger_words,
+                            size_t num_trigger_words,
+               const llama_token * trigger_tokens,
+                            size_t num_trigger_tokens);
+
 static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
     auto * ctx = (llama_sampler_grammar *) smpl->ctx;
     if (!ctx->grammar) {
         return;
     }
 
-    auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str());
+    std::vector  trigger_words;
+    for (auto & word : ctx->grammar->trigger_words) {
+        trigger_words.push_back(word.c_str());
+    }
+    auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),
+                                                 ctx->grammar->lazy, trigger_words.data(), trigger_words.size(),
+                                                 ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size());
 
     llama_grammar_free_impl(ctx->grammar);
     ctx->grammar = grammar_new;
@@ -1445,7 +1472,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
 static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
     const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
 
-    auto * result = llama_sampler_init_grammar_impl(*ctx->vocab, nullptr, nullptr);
+    auto * result = llama_sampler_init_grammar_impl(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0);
 
     // copy the state
     {
@@ -1481,29 +1508,55 @@ static struct llama_sampler_i llama_sampler_grammar_i = {
     /* .free   = */ llama_sampler_grammar_free,
 };
 
-struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) {
+static struct llama_sampler * llama_sampler_init_grammar_impl(
+        const struct llama_vocab * vocab,
+                      const char * grammar_str,
+                      const char * grammar_root,
+                              bool lazy,
+                     const char ** trigger_words,
+                            size_t num_trigger_words,
+               const llama_token * trigger_tokens,
+                            size_t num_trigger_tokens) {
     auto * ctx = new llama_sampler_grammar;
 
     if (grammar_str != nullptr && grammar_str[0] != '\0') {
         *ctx = {
-            /* .vocab        = */ &vocab,
+            /* .vocab        = */ vocab,
             /* .grammar_str  = */ grammar_str,
             /* .grammar_root = */ grammar_root,
-            /* .grammar      = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root),
+            /* .grammar      = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens),
         };
     } else {
         *ctx = {
-            /* .vocab        = */ &vocab,
+            /* .vocab        = */ vocab,
             /* .grammar_str  = */ {},
             /* .grammar_root = */ {},
             /* .grammar      = */ nullptr,
         };
     }
 
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_grammar_i,
-        /* .ctx   = */ ctx,
-    };
+        /* .ctx   = */ ctx
+    );
+}
+
+struct llama_sampler * llama_sampler_init_grammar(
+        const struct llama_vocab * vocab,
+                      const char * grammar_str,
+                      const char * grammar_root) {
+    return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ false, nullptr, 0, nullptr, 0);
+}
+
+struct llama_sampler * llama_sampler_init_grammar_lazy(
+        const struct llama_vocab * vocab,
+                      const char * grammar_str,
+                      const char * grammar_root,
+                     const char ** trigger_words,
+                            size_t num_trigger_words,
+               const llama_token * trigger_tokens,
+                            size_t num_trigger_tokens) {
+    return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens);
 }
 
 // penalties
@@ -1632,7 +1685,7 @@ struct llama_sampler * llama_sampler_init_penalties(
         float penalty_present) {
     penalty_last_n = std::max(penalty_last_n, 0);
 
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_penalties_i,
         /* .ctx   = */ new llama_sampler_penalties {
             /* .penalty_last_n  = */ penalty_last_n,
@@ -1641,8 +1694,75 @@ struct llama_sampler * llama_sampler_init_penalties(
             /* .penalty_present = */ penalty_present,
             /* .prev            = */ ring_buffer(penalty_last_n),
             /* .token_count     = */ {},
-        },
-    };
+        }
+    );
+}
+
+// top-n-sigma
+
+struct llama_sampler_top_n_sigma {
+    const float n;
+};
+
+static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) {
+    return "top-n-sigma";
+}
+
+static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
+
+    // find max logit and calculate mean
+    float max = cur_p->data[0].logit;
+    float logits_sum = 0;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (cur_p->data[i].logit > max) {
+            max = cur_p->data[i].logit;
+        }
+        logits_sum += cur_p->data[i].logit;
+    }
+    float mean = logits_sum/cur_p->size;
+
+    // calculate standard deviation
+    float acc = 0;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        acc += pow(cur_p->data[i].logit - mean, 2);
+    }
+    float std = sqrt(acc/cur_p->size);
+
+    //apply mask
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (cur_p->data[i].logit < max - (ctx->n * std)) {
+            cur_p->data[i].logit = -INFINITY;
+        }
+    }
+    llama_sampler_softmax_impl(cur_p);
+}
+
+static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_top_n_sigma *) smpl->ctx;
+    return llama_sampler_init_top_n_sigma(ctx->n);
+}
+
+static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_top_n_sigma *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
+    /* .name   = */ llama_sampler_top_n_sigma_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_top_n_sigma_apply,
+    /* .reset  = */ nullptr,
+    /* .clone  = */ llama_sampler_top_n_sigma_clone,
+    /* .free   = */ llama_sampler_top_n_sigma_free,
+};
+
+struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
+    return llama_sampler_init(
+        /* .iface = */ &llama_sampler_top_n_sigma_i,
+        /* .ctx   = */ new llama_sampler_top_n_sigma {
+            /* .n = */ n,
+        }
+    );
 }
 
 // DRY
@@ -1663,8 +1783,8 @@ struct llama_sampler_dry {
 
 // Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
 static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap>& token_sequences, int max_tail_len = -1) {
-    for (llama_token token_id = 0; token_id < (llama_token)vocab.n_vocab; token_id++) {
-        std::string word = llama_detokenize(vocab, {token_id}, true);
+    for (llama_token token_id = 0; token_id < (llama_token) vocab.n_tokens(); token_id++) {
+        std::string word = vocab.detokenize({token_id}, true);
         if (word.find(str) != std::string::npos) {
             token_sequences.emplace(token_id, std::vector());
         } else {
@@ -1681,7 +1801,7 @@ static void get_overlapping_token_sequences(const llama_vocab & vocab, const std
                     }
                 }
                 if (match) {
-                    std::vector tokenization = llama_tokenize_internal(vocab, str.substr(i), false, false);
+                    std::vector tokenization = vocab.tokenize(str.substr(i), false, false);
                     if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) {
                         tokenization.resize(max_tail_len);
                     }
@@ -1832,7 +1952,7 @@ static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_dat
                 ctx->dry_repeat_count[last - k] = std::min(n, rep_limit);
                 if (n > 0) {
                     lt = k;
-                    rt = k+n-1;
+                    rt = k + n - 1;
                 }
             } else {
                 // If k is inside the current Z-box, consider two cases.
@@ -1937,7 +2057,7 @@ static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler
     llama_vocab dummy_vocab;
 
     // dummy vocab is passed because it is only needed for raw sequence breaker processing, which we have already done and will simply be copying
-    auto * result = llama_sampler_init_dry_impl(dummy_vocab, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);
+    auto * result = llama_sampler_init_dry(&dummy_vocab, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);
 
     // Copy the state, including the processed breakers
     {
@@ -1964,7 +2084,7 @@ static struct llama_sampler_i llama_sampler_dry_i = {
     /* .free   = */ llama_sampler_dry_free,
 };
 
-struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
+struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
     int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0);
     std::unordered_multimap> processed_breakers;
     const int MAX_CHAR_LEN = 40;
@@ -1991,11 +2111,11 @@ struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vo
                 sequence_break.resize(MAX_CHAR_LEN);
             }
 
-            get_overlapping_token_sequences(vocab, sequence_break, processed_breakers, MAX_SEQ_LEN);
+            get_overlapping_token_sequences(*vocab, sequence_break, processed_breakers, MAX_SEQ_LEN);
         }
     }
 
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_dry_i,
         /* .ctx   = */ new llama_sampler_dry {
             /* .total_context_size     = */ context_size,
@@ -2007,14 +2127,14 @@ struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vo
             /* .dry_repeat_count       = */ dry_enabled ? std::vector(effective_dry_penalty_last_n, 0) : std::vector{},
             /* .dry_max_token_repeat   = */ {},
             /* .last_tokens            = */ dry_enabled ? ring_buffer(effective_dry_penalty_last_n) : ring_buffer(0),
-        },
-    };
+        }
+    );
 }
 
 // wrapper for test-sampling.cpp
 struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector>& seq_breakers) {
     llama_vocab dummy_vocab;
-    auto * result = llama_sampler_init_dry_impl(dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0);
+    auto * result = llama_sampler_init_dry(&dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0);
     auto * ctx = (llama_sampler_dry *) result->ctx;
 
     // Process the token-based sequence breakers
@@ -2109,14 +2229,14 @@ struct llama_sampler * llama_sampler_init_logit_bias(
                          int32_t   n_vocab,
                          int32_t   n_logit_bias,
           const llama_logit_bias * logit_bias) {
-    return new llama_sampler {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_logit_bias_i,
         /* .ctx   = */ new llama_sampler_logit_bias {
             /* .n_vocab    = */ n_vocab,
             /* .logit_bias = */ std::vector(logit_bias, logit_bias + n_logit_bias),
             /* .to_search  = */ {},
-        },
-    };
+        }
+    );
 }
 
 // infill
@@ -2153,7 +2273,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
     float p_eog_sum = 0.0f;
 
     for (size_t i = 0; i < cur_p->size; ++i) {
-        if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+        if (ctx->vocab->is_eog(cur_p->data[i].id)) {
             p_eog_sum += cur_p->data[i].p;
         } else {
             p_txt_sum += cur_p->data[i].p;
@@ -2175,7 +2295,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
         float p_sum = 0.0f;
 
         for (size_t i = 0; i < size_org; ++i) {
-            if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+            if (ctx->vocab->is_eog(cur_p->data[i].id)) {
                 p_sum += cur_p->data[i].p;
 
                 cur_p->data[cur_p->size++] = cur_p->data[i];
@@ -2203,17 +2323,17 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
                 continue;
             }
 
-            int len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
+            int len0 = ctx->vocab->token_to_piece(cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
             if (len0 < 0) {
                 ctx->buf0.resize(len0);
-                len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
+                len0 = ctx->vocab->token_to_piece(cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
                 assert(len0 > 0);
             }
 
-            int len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
+            int len1 = ctx->vocab->token_to_piece(cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
             if (len1 < 0) {
                 ctx->buf1.resize(len1);
-                len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
+                len1 = ctx->vocab->token_to_piece(cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
                 assert(len1 > 0);
             }
 
@@ -2248,7 +2368,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
     LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
 
     for (size_t i = 0; i < size_org; ++i) {
-        const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
+        const bool is_eog = ctx->vocab->is_eog(cur_p->data[i].id);
 
         if (cur_p->data[i].p < thold && !is_eog) {
             continue;
@@ -2269,7 +2389,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
     // if no non-EOG tokens are left -> reduce cur_p to single EOT token
     if (n_non_eog == 0) {
         cur_p->size = 1;
-        cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab);
+        cur_p->data[0].id = ctx->vocab->token_eot();
         cur_p->data[0].logit = 1.0f;
 
         return;
@@ -2291,7 +2411,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
     LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
 
     for (size_t i = 0; i < size_org; ++i) {
-        const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
+        const bool is_eog = ctx->vocab->is_eog(cur_p->data[i].id);
 
         if (cur_p->data[i].p < thold && !is_eog) {
             continue;
@@ -2314,7 +2434,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
 
 static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
     const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
-    return llama_sampler_init_infill_impl(*ctx->vocab);
+    return llama_sampler_init_infill(ctx->vocab);
 }
 
 static void llama_sampler_infill_free(struct llama_sampler * smpl) {
@@ -2330,16 +2450,15 @@ static struct llama_sampler_i llama_sampler_infill_i = {
     /* .free   = */ llama_sampler_infill_free,
 };
 
-struct llama_sampler * llama_sampler_init_infill_impl(
-        const struct llama_vocab & vocab) {
-    return new llama_sampler {
+struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
+    return llama_sampler_init(
         /* .iface = */ &llama_sampler_infill_i,
         /* .ctx   = */ new llama_sampler_infill {
-            /* .vocab = */ &vocab,
-            /* .buf0 = */ std::vector(512),
-            /* .buf1 = */ std::vector(512),
-        },
-    };
+            /* .vocab = */ vocab,
+            /* .buf0  = */ std::vector(512),
+            /* .buf1  = */ std::vector(512),
+        }
+    );
 }
 
 // utils
diff --git a/llama/llama.cpp/src/llama-sampling.h b/llama/llama.cpp/src/llama-sampling.h
index 919f6fdf..759dd7dc 100644
--- a/llama/llama.cpp/src/llama-sampling.h
+++ b/llama/llama.cpp/src/llama-sampling.h
@@ -2,7 +2,9 @@
 
 // TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ?
 
-#include "llama-grammar.h"
+#include "llama.h"
+
+#include 
 
 struct llama_vocab;
 struct llama_grammar;
@@ -21,24 +23,6 @@ struct llama_sampler_chain {
     mutable int32_t n_sample;
 };
 
-struct llama_sampler * llama_sampler_init_grammar_impl(
-        const struct llama_vocab & vocab,
-                      const char * grammar_str,
-                      const char * grammar_root);
-
-struct llama_sampler * llama_sampler_init_infill_impl(
-        const struct llama_vocab & vocab);
-
-struct llama_sampler * llama_sampler_init_dry_impl(
-        const struct llama_vocab &  vocab,
-                         int32_t    context_size,
-                           float    dry_multiplier,
-                           float    dry_base,
-                         int32_t    dry_allowed_length,
-                         int32_t    dry_penalty_last_n,
-                      const char ** seq_breakers,
-                          size_t    num_breakers);
-
 struct llama_sampler * llama_sampler_init_dry_testing(
                          int32_t   context_size,
                            float   dry_multiplier,
diff --git a/llama/llama.cpp/src/llama-vocab.cpp b/llama/llama.cpp/src/llama-vocab.cpp
index 8f44705a..1ca827eb 100644
--- a/llama/llama.cpp/src/llama-vocab.cpp
+++ b/llama/llama.cpp/src/llama-vocab.cpp
@@ -1,6 +1,7 @@
 #include "llama-vocab.h"
 
 #include "llama-impl.h"
+#include "llama-model-loader.h"
 
 #include "unicode.h"
 
@@ -11,8 +12,10 @@
 #include 
 #include 
 #include 
+#include 
 #include 
-#include 
+#include 
+#include 
 
 //
 // helpers
@@ -62,96 +65,14 @@ struct naive_trie {
 };
 
 //
-// impl
+// tokenizers
 //
 
 struct llm_tokenizer {
-   llm_tokenizer() {}
-   virtual ~llm_tokenizer() = default;
+    llm_tokenizer() {}
+    virtual ~llm_tokenizer() = default;
 };
 
-llama_vocab::~llama_vocab() {
-    delete tokenizer;
-}
-
-int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string & token_right) const {
-    GGML_ASSERT(token_left.find(' ')   == std::string::npos);
-    GGML_ASSERT(token_left.find('\n')  == std::string::npos);
-    GGML_ASSERT(token_right.find(' ')  == std::string::npos);
-    GGML_ASSERT(token_right.find('\n') == std::string::npos);
-
-    auto it = bpe_ranks.find(std::make_pair(token_left, token_right));
-    if (it == bpe_ranks.end()) {
-        return -1;
-    }
-
-    return it->second;
-}
-
-static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) {
-    return vocab.type;
-}
-
-static bool llama_is_normal_token(const llama_vocab & vocab, llama_token id) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL;
-}
-
-static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token id) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNKNOWN;
-}
-
-static bool llama_is_control_token(const llama_vocab & vocab, llama_token id) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_CONTROL;
-}
-
-static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_BYTE;
-}
-
-static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token id) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_USER_DEFINED;
-}
-
-static bool llama_is_unused_token(const llama_vocab & vocab, llama_token id) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNUSED;
-}
-
-static uint8_t llama_token_to_byte(const llama_vocab & vocab, llama_token id) {
-    GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE);
-    GGML_ASSERT(llama_is_byte_token(vocab, id));
-    const auto & token_data = vocab.id_to_token.at(id);
-    switch (llama_vocab_get_type(vocab)) {
-        case LLAMA_VOCAB_TYPE_SPM:
-        case LLAMA_VOCAB_TYPE_UGM: {
-            auto buf = token_data.text.substr(3, 2);
-            return strtol(buf.c_str(), NULL, 16);
-        }
-        case LLAMA_VOCAB_TYPE_BPE: {
-            GGML_ABORT("fatal error");
-            //return unicode_utf8_to_byte(token_data.text); // TODO: why is this here after GGML_ASSERT?
-        }
-        case LLAMA_VOCAB_TYPE_WPM: {
-            GGML_ABORT("fatal error");
-        }
-        default:
-            GGML_ABORT("fatal error");
-    }
-}
-
-static void llama_escape_whitespace(std::string & text) {
-    replace_all(text, " ", "\xe2\x96\x81");
-}
-
-static void llama_unescape_whitespace(std::string & word) {
-    replace_all(word, "\xe2\x96\x81", " ");
-}
-
 struct llm_symbol {
     using index = int;
     index prev;
@@ -183,14 +104,13 @@ struct llm_bigram_spm {
 };
 
 struct llm_tokenizer_spm : llm_tokenizer {
-    llm_tokenizer_spm(const llama_vocab & /*vocab*/) : llm_tokenizer() {}
+    llm_tokenizer_spm(const llama_vocab & /*vocab*/) {}
 };
 
 struct llm_tokenizer_spm_session {
     llm_tokenizer_spm_session(const llama_vocab & vocab) : vocab(vocab) {}
 
-    void tokenize(const std::string & text, std::vector & output) {
-
+    void tokenize(const std::string & text, std::vector & output) {
         // split string into utf8 chars
         int index = 0;
         size_t offs = 0;
@@ -249,13 +169,13 @@ struct llm_tokenizer_spm_session {
     }
 
 private:
-    void resegment(llm_symbol & symbol, std::vector & output) {
+    void resegment(llm_symbol & symbol, std::vector & output) {
         auto text = std::string(symbol.text, symbol.n);
-        auto token = vocab.token_to_id.find(text);
+        auto token = vocab.text_to_token(text);
 
         // Do we need to support is_unused?
-        if (token != vocab.token_to_id.end()) {
-            output.push_back((*token).second);
+        if (token != LLAMA_TOKEN_NULL) {
+            output.push_back(token);
             return;
         }
 
@@ -265,8 +185,8 @@ private:
             // output any symbols that did not form tokens as bytes.
             output.reserve(output.size() + symbol.n);
             for (int j = 0; j < (int)symbol.n; ++j) {
-                llama_vocab::id token_id = llama_byte_to_token_impl(vocab, symbol.text[j]);
-                output.push_back(token_id);
+                llama_token id = vocab.byte_to_token(symbol.text[j]);
+                output.push_back(id);
             }
             return;
         }
@@ -280,17 +200,17 @@ private:
             return;
         }
         const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n);
-        auto token = vocab.token_to_id.find(text);
+        auto token = vocab.text_to_token(text);
 
-        if (token == vocab.token_to_id.end()) {
+        if (token == LLAMA_TOKEN_NULL) {
             return;
         }
 
-        if (static_cast((*token).second) >= vocab.id_to_token.size()) {
+        if (static_cast(token) >= vocab.n_tokens()) {
             return;
         }
 
-        const auto & tok_data = vocab.id_to_token[(*token).second];
+        const auto & tok_data = vocab.get_token_data(token);
 
         llm_bigram_spm bigram;
         bigram.left  = left;
@@ -353,9 +273,9 @@ struct llm_bigram_bpe {
 };
 
 struct llm_tokenizer_bpe : llm_tokenizer {
-    llm_tokenizer_bpe(const llama_vocab & vocab) : llm_tokenizer() {
-        GGML_ASSERT(vocab.type == LLAMA_VOCAB_TYPE_BPE);
-        switch (vocab.type_pre) {
+    llm_tokenizer_bpe(const llama_vocab & vocab) {
+        GGML_ASSERT(vocab.get_type() == LLAMA_VOCAB_TYPE_BPE);
+        switch (vocab.get_pre_type()) {
             case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
                 regex_exprs = {
                     // original regex from tokenizer.json
@@ -488,39 +408,38 @@ struct llm_tokenizer_bpe : llm_tokenizer {
 };
 
 struct llm_tokenizer_bpe_session {
-    llm_tokenizer_bpe_session(const llama_vocab & vocab) : vocab(vocab),
-        bpe_tokenizer(static_cast(vocab.tokenizer)) {}
+    llm_tokenizer_bpe_session(const llama_vocab & vocab, const llm_tokenizer_bpe & tokenizer) : vocab(vocab), tokenizer(tokenizer) {}
 
-    static void append(const llama_vocab::id token_id, std::vector & output)  {
+    static void append(const llama_token token_id, std::vector & output)  {
         output.push_back(token_id);
     }
 
-    bool append_bos(std::vector & output) const {
-        if (vocab.tokenizer_add_bos) {
-            GGML_ASSERT(vocab.special_bos_id != -1);
-            output.push_back(vocab.special_bos_id);
+    bool append_bos(std::vector & output) const {
+        if (vocab.get_add_bos()) {
+            GGML_ASSERT(vocab.token_bos() != LLAMA_TOKEN_NULL);
+            output.push_back(vocab.token_bos());
             return true;
         }
         return false;
     }
 
-    bool append_eos(std::vector & output) const {
-        if (vocab.tokenizer_add_eos) {
-            GGML_ASSERT(vocab.special_eos_id != -1);
-            output.push_back(vocab.special_eos_id);
+    bool append_eos(std::vector & output) const {
+        if (vocab.get_add_eos()) {
+            GGML_ASSERT(vocab.token_eos() != LLAMA_TOKEN_NULL);
+            output.push_back(vocab.token_eos());
             return true;
         }
         return false;
     }
 
-    void check_double_bos_eos(const std::vector & output) const {
-        if (vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
+    void check_double_bos_eos(const std::vector & output) const {
+        if (vocab.get_add_bos() && output.size() >= 2 && output[1] == vocab.token_bos()) {
             LLAMA_LOG_WARN(
                 "%s: Added a BOS token to the prompt as specified by the model but the prompt "
                 "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
                 "Are you sure this is what you want?\n", __FUNCTION__);
         }
-        if (vocab.tokenizer_add_eos && output.size() >= 2 && *(output.end()-2) == vocab.special_eos_id) {
+        if (vocab.get_add_eos() && output.size() >= 2 && *(output.end()-2) == vocab.token_eos()) {
             LLAMA_LOG_WARN(
                 "%s: Added a EOS token to the prompt as specified by the model but the prompt "
                 "also ends with a EOS token. So now the final prompt ends with 2 EOS tokens. "
@@ -528,9 +447,9 @@ struct llm_tokenizer_bpe_session {
         }
     }
 
-    void tokenize(const std::string & text, std::vector & output) {
+    void tokenize(const std::string & text, std::vector & output) {
         int final_prev_index = -1;
-        const auto word_collection = unicode_regex_split(text, bpe_tokenizer->regex_exprs);
+        const auto word_collection = unicode_regex_split(text, tokenizer.regex_exprs);
 
         symbols_final.clear();
 
@@ -541,7 +460,8 @@ struct llm_tokenizer_bpe_session {
             int index = 0;
             size_t offset = 0;
 
-            if (vocab.tokenizer_ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) {
+            //if (vocab.tokenizer_ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) {
+            if (vocab.get_ignore_merges() && vocab.text_to_token(word) != LLAMA_TOKEN_NULL) {
                 symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()});
                 offset = word.size();
             }
@@ -615,18 +535,18 @@ struct llm_tokenizer_bpe_session {
                 }
 
                 const std::string str = std::string(symbol.text, symbol.n);
-                const auto token = vocab.token_to_id.find(str);
+                const auto token = vocab.text_to_token(str);
 
-                if (token == vocab.token_to_id.end()) {
+                if (token == LLAMA_TOKEN_NULL) {
                     for (auto j = str.begin(); j != str.end(); ++j) {
                         std::string byte_str(1, *j);
-                        auto token_multibyte = vocab.token_to_id.find(byte_str);
-                        if (token_multibyte != vocab.token_to_id.end()) {
-                            output.push_back(token_multibyte->second);
+                        auto token_multibyte = vocab.text_to_token(byte_str);
+                        if (token_multibyte != LLAMA_TOKEN_NULL) {
+                            output.push_back(token_multibyte);
                         }
                     }
                 } else {
-                    output.push_back((*token).second);
+                    output.push_back(token);
                 }
             }
         }
@@ -660,7 +580,7 @@ private:
     }
 
     const llama_vocab & vocab;
-    const llm_tokenizer_bpe * bpe_tokenizer;
+    const llm_tokenizer_bpe & tokenizer;
 
     std::vector symbols;
     std::vector symbols_final;
@@ -672,14 +592,13 @@ private:
 //
 
 struct llm_tokenizer_wpm : llm_tokenizer {
-    llm_tokenizer_wpm(const llama_vocab & /*vocab*/) : llm_tokenizer() {}
+    llm_tokenizer_wpm(const llama_vocab & /*vocab*/) {}
 };
 
 struct llm_tokenizer_wpm_session {
     llm_tokenizer_wpm_session(const llama_vocab & vocab) : vocab(vocab) {}
 
-    void tokenize(const std::string & text, std::vector & output) {
-        const auto & token_map = vocab.token_to_id;
+    void tokenize(const std::string & text, std::vector & output) {
         // normalize and split by whitespace
         std::vector words = preprocess(text);
         // bos token prepended already
@@ -702,10 +621,10 @@ struct llm_tokenizer_wpm_session {
             for (int i = 0; i < n; ++i) {
                 // loop through possible match length
                 bool match = false;
-                for (int j = std::min(n, i + vocab.max_token_len + 1); j > i; j--) {
-                    auto it = token_map.find(word1.substr(i, j - i));
-                    if (it != token_map.end()) {
-                        output.push_back(it->second);
+                for (int j = std::min(n, i + vocab.max_token_len() + 1); j > i; j--) {
+                    auto id = vocab.text_to_token(word1.substr(i, j - i));
+                    if (id != LLAMA_TOKEN_NULL) {
+                        output.push_back(id);
                         match = true;
                         i = j - 1;
                         break;
@@ -720,7 +639,7 @@ struct llm_tokenizer_wpm_session {
 
             // we didn't find any matches for this word
             if (current_tokens == output.size()) {
-                output.push_back(vocab.special_unk_id);
+                output.push_back(vocab.token_unk());
             }
         }
     }
@@ -789,45 +708,45 @@ private:
 //
 
 struct llm_tokenizer_ugm : llm_tokenizer {
-    llm_tokenizer_ugm(const llama_vocab & vocab) : llm_tokenizer() {
-        if (vocab.precompiled_charsmap.size() > 0) {
+    llm_tokenizer_ugm(const llama_vocab & vocab, const std::vector & precompiled_charsmap) {
+        if (precompiled_charsmap.size() > 0) {
             size_t charsmap_offset = 0;
 
             // First four bytes of precompiled_charsmap contains length of binary
             // blob containing XOR-compressed compact double array (XCDA) entries
-            uint32_t xcda_blob_size = *(const uint32_t *) &vocab.precompiled_charsmap[0];
+            uint32_t xcda_blob_size = *(const uint32_t *) &precompiled_charsmap[0];
             charsmap_offset += sizeof(xcda_blob_size);
-            if (xcda_blob_size + charsmap_offset >= vocab.precompiled_charsmap.size()) {
+            if (xcda_blob_size + charsmap_offset >= precompiled_charsmap.size()) {
                 throw std::runtime_error("Index out of array bounds in precompiled charsmap!");
             }
 
             // Next xcda_blob_size bytes contain entries of XOR-compressed compact
             // double array (XCDA). Each entry is bit-packed into a 32-bit integer.
-            xcda_array = (const uint32_t *) &vocab.precompiled_charsmap[charsmap_offset];
+            xcda_array = (const uint32_t *) &precompiled_charsmap[charsmap_offset];
             xcda_array_size = xcda_blob_size / sizeof(uint32_t);
             charsmap_offset += xcda_blob_size;
 
             // Remaining bytes of precompiled charsmap contain null-terminated
             // replacement strings for prefixes matched by the XCDA.
-            prefix_replacements = &vocab.precompiled_charsmap[charsmap_offset];
-            prefix_replacements_size = vocab.precompiled_charsmap.size() - charsmap_offset;
+            prefix_replacements = &precompiled_charsmap[charsmap_offset];
+            prefix_replacements_size = precompiled_charsmap.size() - charsmap_offset;
         }
 
-        for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) {
-            const auto &token_data = vocab.id_to_token[id];
+        for (uint32_t id = 0; id < vocab.n_tokens(); ++id) {
+            const auto & token_data = vocab.get_token_data(id);
 
-            if (llama_is_normal_token(vocab, id)) {
+            if (vocab.is_normal(id)) {
                 min_score = std::min(min_score, token_data.score);
                 max_score = std::max(max_score, token_data.score);
             }
 
-            if (llama_is_normal_token(vocab, id) ||
-                llama_is_user_defined_token(vocab, id) ||
-                llama_is_unused_token(vocab, id)) {
+            if (vocab.is_normal(id) ||
+                vocab.is_user_defined(id) ||
+                vocab.is_unused(id)) {
                 token_matcher.insert(token_data.text.data(), token_data.text.size(), id);
             }
 
-            if (llama_is_user_defined_token(vocab, id)) {
+            if (vocab.is_user_defined(id)) {
                 user_defined_token_matcher.insert(token_data.text.data(), token_data.text.size());
             }
         }
@@ -856,8 +775,7 @@ struct llm_tokenizer_ugm : llm_tokenizer {
 };
 
 struct llm_tokenizer_ugm_session {
-    llm_tokenizer_ugm_session(const llama_vocab & vocab) : vocab(vocab),
-        ugm_tokenizer(static_cast(vocab.tokenizer)) {}
+    llm_tokenizer_ugm_session(const llama_vocab & vocab, const llm_tokenizer_ugm & tokenizer) : vocab(vocab), tokenizer(tokenizer) {}
 
     /* This implementation is based on SentencePiece optimized Viterbi algorithm for
      * unigram language models. The general idea is to:
@@ -872,7 +790,7 @@ struct llm_tokenizer_ugm_session {
      * After processing the whole sequence we backtrack from the end to get
      * the best tokenization.
     */
-    void tokenize(const std::string & text, std::vector & output) {
+    void tokenize(const std::string & text, std::vector & output) {
         // get current size of output (for reversal later)
         size_t output_size = output.size();
 
@@ -885,9 +803,9 @@ struct llm_tokenizer_ugm_session {
         }
 
         // initialize score_sum to -FLT_MAX so it will be always lower than sums of token scores
-        std::vector tokenization_results(input_len + 1, {vocab.special_unk_id, 0, -FLT_MAX});
+        std::vector tokenization_results(input_len + 1, {vocab.token_unk(), 0, -FLT_MAX});
         // at the beginning tokenization score is zero
-        tokenization_results[0] = { vocab.special_unk_id, 0, 0 };
+        tokenization_results[0] = { vocab.token_unk(), 0, 0 };
 
         for (size_t input_offset = 0; input_offset < input_len;) {
             size_t prefix_offset = input_offset;
@@ -897,7 +815,7 @@ struct llm_tokenizer_ugm_session {
             // traverse the token matcher trie to find a matching token
             bool single_codepoint_token_found = false;
             const struct best_tokenization & current_best = tokenization_results[input_offset];
-            const struct naive_trie * node = ugm_tokenizer->token_matcher.traverse(normalized[prefix_offset++]);
+            const struct naive_trie * node = tokenizer.token_matcher.traverse(normalized[prefix_offset++]);
 
             while (prefix_offset <= input_len && node != NULL) {
                 // check if we found valid token in prefix
@@ -907,13 +825,13 @@ struct llm_tokenizer_ugm_session {
                         single_codepoint_token_found = true;
                     }
                     llama_token token_id = node->value;
-                    const auto & token_data = vocab.id_to_token[token_id];
+                    const auto & token_data = vocab.get_token_data(token_id);
 
                     // we set the user-defined token scores to 0 to make them more likely to be selected
                     // (normal token scores are log probabilities, so they are negative)
                     // score type is double here to make tokenization results exactly
                     // the same as in the HF tokenizer using SentencePiece
-                    const double token_score = llama_is_user_defined_token(vocab, token_id) ? 0.0 : token_data.score;
+                    const double token_score = vocab.is_user_defined(token_id) ? 0.0 : token_data.score;
                     const double challenger_score = current_best.score_sum + token_score;
                     struct best_tokenization & current_champ = tokenization_results[prefix_offset];
                     if (challenger_score > current_champ.score_sum) {
@@ -927,11 +845,11 @@ struct llm_tokenizer_ugm_session {
             // if we didn't find a valid token corresponding to the whole UTF code point
             // then use unknown token as the tokenization of this UTF code point
             if (!single_codepoint_token_found) {
-                const double challenger_score = current_best.score_sum + ugm_tokenizer->unknown_token_score;
+                const double challenger_score = current_best.score_sum + tokenizer.unknown_token_score;
                 prefix_offset = input_offset + n_utf8_code_units;
                 struct best_tokenization & current_champ = tokenization_results[prefix_offset];
                 if (challenger_score > current_champ.score_sum) {
-                    struct best_tokenization challenger = { vocab.special_unk_id, input_offset, (float) challenger_score };
+                    struct best_tokenization challenger = { vocab.token_unk(), input_offset, (float) challenger_score };
                     current_champ = challenger;
                 }
             }
@@ -944,7 +862,7 @@ struct llm_tokenizer_ugm_session {
         // merge sequences of consecutive unknown tokens into single unknown tokens
         bool is_prev_unknown = false;
         for (struct best_tokenization & tokenization = tokenization_results[input_len]; ; tokenization = tokenization_results[tokenization.input_offset]) {
-            bool is_unknown = tokenization.token_id == vocab.special_unk_id;
+            bool is_unknown = tokenization.token_id == vocab.token_unk();
             if (!(is_prev_unknown && is_unknown)) {
                 output.push_back(tokenization.token_id);
             }
@@ -971,11 +889,11 @@ private:
         normalized->clear();
         normalized->reserve(input.size() * 3);
 
-        const std::string space = vocab.tokenizer_escape_whitespaces ? ugm_tokenizer->escaped_space : " ";
+        const std::string space = vocab.get_escape_whitespaces() ? tokenizer.escaped_space : " ";
 
-        bool shall_prepend_space = !vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix;
-        bool shall_append_space = vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix;
-        bool shall_merge_spaces = vocab.tokenizer_remove_extra_whitespaces;
+        const bool shall_prepend_space = !vocab.get_treat_whitespace_as_suffix() && vocab.get_add_space_prefix();
+        const bool shall_append_space  =  vocab.get_treat_whitespace_as_suffix() && vocab.get_add_space_prefix();
+        const bool shall_merge_spaces  =  vocab.get_remove_extra_whitespaces();
 
         bool is_space_prepended = false;
         bool processing_non_ws = false;
@@ -1067,7 +985,7 @@ private:
 
         // if input prefix matches some user-defined token return this token as normalization result
         auto user_defined_token_match =
-           ugm_tokenizer->user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset);
+           tokenizer.user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset);
         if (user_defined_token_match.second > 0) {
             return { &input[input_offset], user_defined_token_match.second, user_defined_token_match.second };
         }
@@ -1075,8 +993,8 @@ private:
         size_t longest_prefix_length = 0;
         size_t longest_prefix_offset = 0;
 
-        if (ugm_tokenizer->xcda_array_size > 0) {
-            struct xcda_array_view xcda_view(ugm_tokenizer->xcda_array, ugm_tokenizer->xcda_array_size);
+        if (tokenizer.xcda_array_size > 0) {
+            struct xcda_array_view xcda_view(tokenizer.xcda_array, tokenizer.xcda_array_size);
 
             // Find the longest normalized sequence matching the input prefix by walking
             // the XOR-compressed compact double array (XCDA) starting from the root node
@@ -1112,10 +1030,10 @@ private:
 
         if (longest_prefix_length > 0) {
             // we have a match, so return the replacement sequence
-            if (longest_prefix_offset >= ugm_tokenizer->prefix_replacements_size) {
+            if (longest_prefix_offset >= tokenizer.prefix_replacements_size) {
                 throw std::runtime_error("Index out of array bounds in precompiled charsmap!");
             }
-            const char * prefix_replacement = &(ugm_tokenizer->prefix_replacements)[longest_prefix_offset];
+            const char * prefix_replacement = &(tokenizer.prefix_replacements)[longest_prefix_offset];
             return { prefix_replacement, strlen(prefix_replacement), longest_prefix_length };
         }
 
@@ -1132,7 +1050,7 @@ private:
     }
 
     const llama_vocab & vocab;
-    const llm_tokenizer_ugm * ugm_tokenizer;
+    const llm_tokenizer_ugm & tokenizer;
 };
 
 //
@@ -1194,15 +1112,15 @@ static std::vector llama_unescape_rwkv_token(const std::string & escape
 }
 
 struct llm_tokenizer_rwkv : llm_tokenizer {
-    llm_tokenizer_rwkv(const llama_vocab & vocab) : llm_tokenizer() {
+    llm_tokenizer_rwkv(const llama_vocab & vocab) {
         // RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens.
         // For now, we decode the vocab here into the lookup we'll use for tokenization.
 
         // build trie
-        for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) {
-            const auto & token = vocab.id_to_token[id];
-            const auto data = llama_unescape_rwkv_token(token.text);
-            token_matcher.insert((const char *) data.data(), data.size(), id);
+        for (uint32_t id = 0; id < vocab.n_tokens(); ++id) {
+            const auto & data = vocab.get_token_data(id);
+            const auto text = llama_unescape_rwkv_token(data.text);
+            token_matcher.insert((const char *) text.data(), text.size(), id);
         }
     }
 
@@ -1210,16 +1128,15 @@ struct llm_tokenizer_rwkv : llm_tokenizer {
 };
 
 struct llm_tokenizer_rwkv_session {
-    llm_tokenizer_rwkv_session(const llama_vocab & vocab) : vocab(vocab),
-        rwkv_tokenizer(static_cast(*vocab.tokenizer)) {}
+    llm_tokenizer_rwkv_session(const llama_vocab & vocab, const llm_tokenizer_rwkv & tokenizer) : vocab(vocab), tokenizer(tokenizer) {}
 
-    void tokenize(const std::string & text, std::vector & output) {
+    void tokenize(const std::string & text, std::vector & output) {
         uint32_t position = 0;
         while (position < text.size()) {
-            const struct naive_trie * node = rwkv_tokenizer.token_matcher.traverse(text[position]);
+            const struct naive_trie * node = tokenizer.token_matcher.traverse(text[position]);
             if (node == NULL) {
                 // no matching token found, add unknown token
-                output.push_back(vocab.special_unk_id);
+                output.push_back(vocab.token_unk());
                 position += 1;
                 continue;
             }
@@ -1243,33 +1160,11 @@ struct llm_tokenizer_rwkv_session {
 
 private:
     const llama_vocab & vocab;
-    const llm_tokenizer_rwkv & rwkv_tokenizer;
+    const llm_tokenizer_rwkv & tokenizer;
 };
 
-void llama_vocab::init_tokenizer() {
-    switch (type) {
-        case LLAMA_VOCAB_TYPE_SPM:
-            tokenizer = new llm_tokenizer_spm(*this);
-            break;
-        case LLAMA_VOCAB_TYPE_BPE:
-            tokenizer = new llm_tokenizer_bpe(*this);
-            break;
-        case LLAMA_VOCAB_TYPE_WPM:
-            tokenizer = new llm_tokenizer_wpm(*this);
-            break;
-        case LLAMA_VOCAB_TYPE_UGM:
-            tokenizer = new llm_tokenizer_ugm(*this);
-            break;
-        case LLAMA_VOCAB_TYPE_RWKV:
-            tokenizer = new llm_tokenizer_rwkv(*this);
-            break;
-        default:
-            GGML_ABORT("unsupported vocab type");
-    }
-}
-
 //
-// (de-) tokenize
+// impl
 //
 
 typedef enum FRAGMENT_BUFFER_VARIANT_TYPE {
@@ -1278,7 +1173,7 @@ typedef enum FRAGMENT_BUFFER_VARIANT_TYPE {
 } FRAGMENT_BUFFER_VARIANT_TYPE;
 
 struct fragment_buffer_variant {
-    fragment_buffer_variant(llama_vocab::id _token)
+    fragment_buffer_variant(llama_token _token)
     :
         type(FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN),
         token(_token),
@@ -1289,7 +1184,7 @@ struct fragment_buffer_variant {
     fragment_buffer_variant(const std::string & _raw_text, int64_t _offset, int64_t _length)
     :
         type(FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT),
-        token((llama_vocab::id) - 1),
+        token((llama_token) - 1),
         raw_text(_raw_text),
         offset(_offset),
         length(_length){
@@ -1299,20 +1194,955 @@ struct fragment_buffer_variant {
         }
 
     const FRAGMENT_BUFFER_VARIANT_TYPE type;
-    const llama_vocab::id token;
+    const llama_token token;
     const std::string _dummy;
     const std::string & raw_text;
     const uint64_t offset;
     const uint64_t length;
 };
 
+struct llama_vocab::impl {
+    uint32_t n_token_types = 0; // for BERT-style token types
+
+    enum llama_vocab_type     type     = LLAMA_VOCAB_TYPE_SPM;
+    enum llama_vocab_pre_type pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+
+    int max_token_len = 0; // used for optimizing longest token search
+
+    // default LLaMA special tokens
+    // TODO: should we set all of these to LLAMA_TOKEN_NULL?
+    llama_token special_bos_id  = 1;
+    llama_token special_eos_id  = 2;
+    llama_token special_eot_id  = LLAMA_TOKEN_NULL;
+    llama_token special_eom_id  = LLAMA_TOKEN_NULL;
+    llama_token special_unk_id  = 0;
+    llama_token special_sep_id  = LLAMA_TOKEN_NULL;
+    llama_token special_pad_id  = LLAMA_TOKEN_NULL;
+    llama_token special_mask_id = LLAMA_TOKEN_NULL;
+
+    llama_token linefeed_id = 13;
+
+    // fim tokens
+    llama_token special_fim_pre_id = LLAMA_TOKEN_NULL;
+    llama_token special_fim_suf_id = LLAMA_TOKEN_NULL;
+    llama_token special_fim_mid_id = LLAMA_TOKEN_NULL;
+    llama_token special_fim_pad_id = LLAMA_TOKEN_NULL;
+    llama_token special_fim_rep_id = LLAMA_TOKEN_NULL; // repo
+    llama_token special_fim_sep_id = LLAMA_TOKEN_NULL; // file separator
+
+    // tokenizer flags
+    bool add_space_prefix           = false;
+    bool add_bos                    = false;
+    bool add_eos                    = false;
+    bool ignore_merges              = false;
+    bool clean_spaces               = false;  // clean_up_tokenization_spaces
+    bool remove_extra_whitespaces   = false;
+    bool escape_whitespaces         = true;
+    bool treat_whitespace_as_suffix = false;
+
+    std::unordered_map token_to_id;
+    std::vector                      id_to_token;
+
+    std::vector cache_special_tokens;
+    std::vector cache_token_to_piece; // llama_token_to_piece(special = true);
+    struct pair_hash {
+        size_t operator()(const std::pair & p) const {
+            return std::hash{}(p.first) ^  //create some hash for pair
+                   (std::hash{}(p.second) << 1);
+        }
+    };
+    std::unordered_map, int, pair_hash> bpe_ranks;
+
+    // set of all tokens that cause "end of generation"
+    std::set special_eog_ids;
+
+    std::unique_ptr tokenizer;
+
+    std::vector precompiled_charsmap;
+
+    impl(const llama_vocab & vocab) : vocab(vocab) {
+    }
+
+    ~impl() = default;
+
+    void load(llama_model_loader & ml, const LLM_KV & kv);
+
+    enum llama_vocab_type get_type() const;
+
+    std::string type_name() const;
+
+    bool is_normal      (llama_token id) const;
+    bool is_unknown     (llama_token id) const;
+    bool is_control     (llama_token id) const;
+    bool is_byte        (llama_token id) const;
+    bool is_user_defined(llama_token id) const;
+    bool is_unused      (llama_token id) const;
+    bool is_eog         (llama_token id) const;
+
+    uint8_t token_to_byte(llama_token id) const;
+
+    llama_token_attr token_get_attr(llama_token id) const;
+
+    void init_tokenizer(enum llama_vocab_type type);
+
+    void tokenizer_st_partition(std::forward_list & buffer, bool parse_special) const;
+
+    std::string token_to_piece_for_cache(
+                  llama_token   token,
+                         bool   special) const;
+
+
+    std::vector tokenize(
+            const std::string & raw_text,
+                         bool   add_special,
+                         bool   parse_special = false) const;
+
+    int32_t tokenize(
+                   const char * text,
+                      int32_t   text_len,
+                  llama_token * tokens,
+                      int32_t   n_tokens_max,
+                         bool   add_special,
+                         bool   parse_special) const;
+
+    // does not write null-terminator to buf
+    int32_t token_to_piece(
+                  llama_token   token,
+                         char * buf,
+                      int32_t   length,
+                      int32_t   lstrip,
+                         bool   special) const;
+
+    // use cached data
+    const std::string & token_to_piece(llama_token token) const;
+
+    int32_t detokenize(
+            const llama_token * tokens,
+                      int32_t   n_tokens,
+                         char * text,
+                      int32_t   text_len_max,
+                         bool   remove_special,
+                         bool   unparse_special) const;
+
+    std::string detokenize(
+            const std::vector & tokens,
+                                      bool   special) const;
+
+    void print_info() const;
+
+private:
+    const llama_vocab & vocab;
+};
+
+void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
+    struct gguf_context * ctx = ml.meta.get();
+
+    // determine vocab type
+    {
+        std::string tokenizer_model;
+        std::string tokenizer_pre;
+
+        ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model);
+        ml.get_key(LLM_KV_TOKENIZER_PRE,   tokenizer_pre, false);
+
+        ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, n_token_types, false);
+
+        if (tokenizer_model == "no_vocab" || tokenizer_model == "none") {
+            type = LLAMA_VOCAB_TYPE_NONE;
+
+            // default special tokens
+            special_bos_id  = LLAMA_TOKEN_NULL;
+            special_eos_id  = LLAMA_TOKEN_NULL;
+            special_unk_id  = LLAMA_TOKEN_NULL;
+            special_sep_id  = LLAMA_TOKEN_NULL;
+            special_pad_id  = LLAMA_TOKEN_NULL;
+            special_mask_id = LLAMA_TOKEN_NULL;
+            linefeed_id     = LLAMA_TOKEN_NULL;
+
+            // read vocab size from metadata
+            uint32_t n_tokens = 0;
+            if (ml.get_key(LLM_KV_VOCAB_SIZE, n_tokens, false)) {
+                LLAMA_LOG_WARN("%s: adding %u dummy tokens\n", __func__, n_tokens);
+                id_to_token.resize(n_tokens);
+            }
+
+            return;
+        }
+
+        if (tokenizer_model == "llama") {
+            type = LLAMA_VOCAB_TYPE_SPM;
+
+            // default special tokens
+            special_bos_id  = 1;
+            special_eos_id  = 2;
+            special_unk_id  = 0;
+            special_sep_id  = LLAMA_TOKEN_NULL;
+            special_pad_id  = LLAMA_TOKEN_NULL;
+            special_mask_id = LLAMA_TOKEN_NULL;
+        } else if (tokenizer_model == "bert") {
+            type = LLAMA_VOCAB_TYPE_WPM;
+
+            // default special tokens
+            special_bos_id  = 101;
+            special_eos_id  = LLAMA_TOKEN_NULL;
+            special_unk_id  = 100;
+            special_sep_id  = 102;
+            special_pad_id  = 0;
+            special_mask_id = 103;
+        } else if (tokenizer_model == "gpt2") {
+            type = LLAMA_VOCAB_TYPE_BPE;
+
+            // read bpe merges and populate bpe ranks
+            const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str());
+            if (merges_keyidx == -1) {
+                throw std::runtime_error("cannot find tokenizer merges in model file\n");
+            }
+
+            const int n_merges = gguf_get_arr_n(ctx, merges_keyidx);
+            for (int i = 0; i < n_merges; i++) {
+                const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i);
+                //GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
+
+                std::string first;
+                std::string second;
+
+                const size_t pos = word.find(' ', 1);
+
+                if (pos != std::string::npos) {
+                    first  = word.substr(0, pos);
+                    second = word.substr(pos + 1);
+                }
+
+                bpe_ranks.emplace(std::make_pair(first, second), i);
+            }
+
+            // default special tokens
+            special_bos_id  = 11;
+            special_eos_id  = 11;
+            special_unk_id  = LLAMA_TOKEN_NULL;
+            special_sep_id  = LLAMA_TOKEN_NULL;
+            special_pad_id  = LLAMA_TOKEN_NULL;
+            special_mask_id = LLAMA_TOKEN_NULL;
+        } else if (tokenizer_model == "t5") {
+            type = LLAMA_VOCAB_TYPE_UGM;
+
+            // default special tokens
+            special_bos_id  = LLAMA_TOKEN_NULL;
+            special_eos_id  = 1;
+            special_unk_id  = 2;
+            special_sep_id  = LLAMA_TOKEN_NULL;
+            special_pad_id  = 0;
+            special_mask_id = LLAMA_TOKEN_NULL;
+
+            const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str());
+            if (precompiled_charsmap_keyidx != -1) {
+                size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
+                const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx);
+                precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap);
+#ifdef IS_BIG_ENDIAN
+                // correct endiannes of data in precompiled_charsmap binary blob
+                uint32_t * xcda_blob_size = (uint32_t *) &precompiled_charsmap[0];
+                *xcda_blob_size = __builtin_bswap32(*xcda_blob_size);
+                assert(*xcda_blob_size + sizeof(uint32_t) < n_precompiled_charsmap);
+                size_t xcda_array_size = *xcda_blob_size / sizeof(uint32_t);
+                uint32_t * xcda_array = (uint32_t *) &precompiled_charsmap[sizeof(uint32_t)];
+                for (size_t i = 0; i < xcda_array_size; ++i) {
+                    xcda_array[i] = __builtin_bswap32(xcda_array[i]);
+                }
+#endif
+            }
+        } else if (tokenizer_model == "rwkv") {
+            type = LLAMA_VOCAB_TYPE_RWKV;
+
+            // default special tokens
+            special_bos_id = LLAMA_TOKEN_NULL;
+            special_eos_id = LLAMA_TOKEN_NULL;
+            special_unk_id = LLAMA_TOKEN_NULL;
+            special_sep_id = LLAMA_TOKEN_NULL;
+            special_pad_id = LLAMA_TOKEN_NULL;
+        } else {
+            throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
+        }
+
+        // for now, only BPE models have pre-tokenizers
+        if (type == LLAMA_VOCAB_TYPE_BPE) {
+            add_space_prefix = false;
+            clean_spaces = true;
+            if (tokenizer_pre == "default") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            } else if (
+                    tokenizer_pre == "llama3"   ||
+                    tokenizer_pre == "llama-v3" ||
+                    tokenizer_pre == "llama-bpe"||
+                    tokenizer_pre == "falcon3") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
+                ignore_merges = true;
+                add_bos = true;
+            } else if (
+                    tokenizer_pre == "deepseek-llm") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM;
+                clean_spaces = false;
+            } else if (
+                    tokenizer_pre == "deepseek-coder") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER;
+                clean_spaces = false;
+            } else if (
+                    tokenizer_pre == "deepseek-v3") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM;
+                clean_spaces = false;
+            } else if (
+                    tokenizer_pre == "falcon") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_FALCON;
+            } else if (
+                    tokenizer_pre == "mpt") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_MPT;
+            } else if (
+                    tokenizer_pre == "starcoder") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_STARCODER;
+            } else if (
+                    tokenizer_pre == "gpt-2"   ||
+                    tokenizer_pre == "phi-2"   ||
+                    tokenizer_pre == "jina-es" ||
+                    tokenizer_pre == "jina-de" ||
+                    tokenizer_pre == "gigachat"   ||
+                    tokenizer_pre == "jina-v1-en" ||
+                    tokenizer_pre == "jina-v2-es" ||
+                    tokenizer_pre == "jina-v2-de" ||
+                    tokenizer_pre == "jina-v2-code" ||
+                    tokenizer_pre == "roberta-bpe") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
+            } else if (
+                    tokenizer_pre == "refact") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_REFACT;
+            } else if (
+                tokenizer_pre == "command-r") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_COMMAND_R;
+                clean_spaces = false;
+            } else if (
+                    tokenizer_pre == "qwen2" ||
+                    tokenizer_pre == "deepseek-r1-qwen") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2;
+                clean_spaces = false;
+            } else if (
+                tokenizer_pre == "stablelm2") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_STABLELM2;
+            } else if (
+                tokenizer_pre == "olmo") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_OLMO;
+            } else if (
+                tokenizer_pre == "dbrx") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_DBRX;
+            } else if (
+                tokenizer_pre == "smaug-bpe") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_SMAUG;
+            } else if (
+                tokenizer_pre == "poro-chat") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_PORO;
+                clean_spaces = false;
+            } else if (
+                tokenizer_pre == "chatglm-bpe") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_CHATGLM4;
+                special_bos_id = LLAMA_TOKEN_NULL;
+            } else if (
+                tokenizer_pre == "viking") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_VIKING;
+                clean_spaces = false;
+            } else if (
+                tokenizer_pre == "jais") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_JAIS;
+            } else if (
+                tokenizer_pre == "tekken") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_TEKKEN;
+                clean_spaces = false;
+                ignore_merges = true;
+                add_bos = true;
+            } else if (
+                tokenizer_pre == "smollm") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_SMOLLM;
+                clean_spaces = false;
+            } else if (
+                tokenizer_pre == "codeshell") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_CODESHELL;
+            } else if (
+                tokenizer_pre == "bloom") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_BLOOM;
+            } else if (
+                tokenizer_pre == "gpt3-finnish") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH;
+            } else if (
+                tokenizer_pre == "exaone") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_EXAONE;
+            } else if (
+                tokenizer_pre == "chameleon") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_CHAMELEON;
+                add_bos = true;
+                clean_spaces = false;
+            } else if (
+                tokenizer_pre == "minerva-7b") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_MINERVA;
+            } else if (
+                tokenizer_pre == "megrez") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2;
+            } else {
+                LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
+                pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            }
+        } else if (type == LLAMA_VOCAB_TYPE_SPM) {
+            pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            add_space_prefix = true;
+            clean_spaces = false;
+            add_bos = true;
+            add_eos = false;
+        } else if (type == LLAMA_VOCAB_TYPE_WPM) {
+            pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            add_space_prefix = false;
+            clean_spaces = true;
+            add_bos = true;
+            add_eos = false;
+        } else if (type == LLAMA_VOCAB_TYPE_UGM) {
+            pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            add_bos = false;
+            add_eos = true;
+        } else if (type == LLAMA_VOCAB_TYPE_RWKV) {
+            pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            add_space_prefix = false;
+            clean_spaces = false;
+            add_bos = false;
+            add_eos = false;
+        } else {
+            pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+        }
+
+        ml.get_key(LLM_KV_TOKENIZER_ADD_PREFIX,      add_space_prefix,         false);
+        ml.get_key(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, remove_extra_whitespaces, false);
+    }
+
+    const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str());
+    if (token_idx == -1) {
+        throw std::runtime_error("cannot find tokenizer vocab in model file\n");
+    }
+
+    const float * scores = nullptr;
+    const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str());
+    if (score_idx != -1) {
+        scores = (const float * ) gguf_get_arr_data(ctx, score_idx);
+    }
+
+    const int * toktypes = nullptr;
+    const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str());
+    if (toktype_idx != -1) {
+        toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx);
+    }
+
+    uint32_t n_tokens = gguf_get_arr_n(ctx, token_idx);
+    id_to_token.resize(n_tokens);
+
+    for (uint32_t i = 0; i < n_tokens; i++) {
+        std::string word = gguf_get_arr_str(ctx, token_idx, i);
+        if (word.empty()) {
+            LLAMA_LOG_WARN("%s: empty token at index %u\n", __func__, i);
+            word = "[EMPTY_" + std::to_string(i) + "]";
+        }
+
+        token_to_id[word] = i;
+        max_token_len = std::max(max_token_len, (int) word.size());
+
+        auto & token_data = id_to_token[i];
+        token_data.text  = std::move(word);
+        token_data.score = scores ? scores[i] : 0.0f;
+        token_data.attr  = LLAMA_TOKEN_ATTR_NORMAL;
+
+        if (toktypes) {  //TODO: remove, required until per token attributes are available from GGUF file
+            switch(toktypes[i]) {
+                case LLAMA_TOKEN_TYPE_UNKNOWN:      token_data.attr = LLAMA_TOKEN_ATTR_UNKNOWN;      break;
+                case LLAMA_TOKEN_TYPE_UNUSED:       token_data.attr = LLAMA_TOKEN_ATTR_UNUSED;       break;
+                case LLAMA_TOKEN_TYPE_NORMAL:       token_data.attr = LLAMA_TOKEN_ATTR_NORMAL;       break;
+                case LLAMA_TOKEN_TYPE_CONTROL:      token_data.attr = LLAMA_TOKEN_ATTR_CONTROL;      break;
+                case LLAMA_TOKEN_TYPE_USER_DEFINED: token_data.attr = LLAMA_TOKEN_ATTR_USER_DEFINED; break;
+                case LLAMA_TOKEN_TYPE_BYTE:         token_data.attr = LLAMA_TOKEN_ATTR_BYTE;         break;
+                case LLAMA_TOKEN_TYPE_UNDEFINED:    token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED;    break;
+                default:                            token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED;    break;
+            }
+        }
+    }
+    GGML_ASSERT(id_to_token.size() == token_to_id.size());
+
+    init_tokenizer(type);
+
+    // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n'
+    if (type == LLAMA_VOCAB_TYPE_SPM) {
+        try {
+            linefeed_id = vocab.byte_to_token('\n');
+        } catch (const std::exception & e) {
+            LLAMA_LOG_WARN("%s: SPM vocabulary, but newline token not found: %s! Using special_pad_id instead.", __func__, e.what());
+            linefeed_id = special_pad_id;
+        }
+    } else if (type == LLAMA_VOCAB_TYPE_WPM) {
+        linefeed_id = special_pad_id;
+    } else if (type == LLAMA_VOCAB_TYPE_RWKV) {
+        const std::vector ids = tokenize("\n", false);
+        GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
+        linefeed_id = ids[0];
+    } else {
+        const std::vector ids = tokenize("\n", false);
+
+        //GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
+        if (ids.empty()) {
+            LLAMA_LOG_WARN("%s: model vocab missing newline token, using special_pad_id instead\n", __func__);
+            linefeed_id = special_pad_id;
+        } else {
+            linefeed_id = ids[0];
+        }
+    }
+
+    // special tokens
+    {
+        const std::vector> special_token_types = {
+            { LLM_KV_TOKENIZER_BOS_ID,     special_bos_id     },
+            { LLM_KV_TOKENIZER_EOS_ID,     special_eos_id     },
+            { LLM_KV_TOKENIZER_EOT_ID,     special_eot_id     },
+            { LLM_KV_TOKENIZER_EOM_ID,     special_eom_id     },
+            { LLM_KV_TOKENIZER_UNK_ID,     special_unk_id     },
+            { LLM_KV_TOKENIZER_SEP_ID,     special_sep_id     },
+            { LLM_KV_TOKENIZER_PAD_ID,     special_pad_id     },
+            { LLM_KV_TOKENIZER_MASK_ID,    special_mask_id    },
+            { LLM_KV_TOKENIZER_FIM_PRE_ID, special_fim_pre_id },
+            { LLM_KV_TOKENIZER_FIM_SUF_ID, special_fim_suf_id },
+            { LLM_KV_TOKENIZER_FIM_MID_ID, special_fim_mid_id },
+            { LLM_KV_TOKENIZER_FIM_PAD_ID, special_fim_pad_id },
+            { LLM_KV_TOKENIZER_FIM_REP_ID, special_fim_rep_id },
+            { LLM_KV_TOKENIZER_FIM_SEP_ID, special_fim_sep_id },
+
+            // deprecated
+            { LLM_KV_TOKENIZER_PREFIX_ID, special_fim_pre_id },
+            { LLM_KV_TOKENIZER_SUFFIX_ID, special_fim_suf_id },
+            { LLM_KV_TOKENIZER_MIDDLE_ID, special_fim_mid_id },
+        };
+
+        for (const auto & it : special_token_types) {
+            const std::string & key = kv(std::get<0>(it));
+            int32_t & id = std::get<1>(it);
+
+            uint32_t new_id;
+            if (!ml.get_key(std::get<0>(it), new_id, false)) {
+                continue;
+            }
+            if (new_id >= id_to_token.size()) {
+                LLAMA_LOG_WARN("%s: bad special token: '%s' = %u, using default id %d\n",
+                    __func__, key.c_str(), new_id, id);
+            } else {
+                id = new_id;
+            }
+        }
+
+        // Handle add_bos and add_eos
+        {
+            bool temp = true;
+
+            if (ml.get_key(LLM_KV_TOKENIZER_ADD_BOS, temp, false)) {
+                add_bos = temp;
+            }
+            if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) {
+                add_eos = temp;
+            }
+        }
+
+        // auto-detect special tokens by text
+        // TODO: convert scripts should provide these tokens through the KV metadata LLM_KV_TOKENIZER_...
+        //       for now, we apply this workaround to find the tokens based on their text
+
+        for (const auto & t : token_to_id) {
+            // find EOT token: "<|eot_id|>", "<|im_end|>", "", etc.
+            if (special_eot_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|eot_id|>"
+                        || t.first == "<|im_end|>"
+                        || t.first == "<|end|>"
+                        || t.first == ""
+                        || t.first == "<|endoftext|>"
+                        || t.first == ""
+                        || t.first == "<|end▁of▁sentence|>" // DeepSeek
+                   ) {
+                    special_eot_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find EOM token: "<|eom_id|>"
+            if (special_eom_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|eom_id|>"
+                        ) {
+                    special_eom_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_PRE token: "<|fim_prefix|>", "", "
", etc.
+            if (special_fim_pre_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_prefix|>"  // Qwen
+                        || t.first == ""
+                        || t.first == "<|fim▁begin|>" // DeepSeek
+                        || t.first == "
"
+                        ) {
+                    special_fim_pre_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_SUF token: "<|fim_suffix|>", "", "", etc.
+            if (special_fim_suf_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_suffix|>" // Qwen
+                        || t.first == ""
+                        || t.first == "<|fim▁hole|>" // DeepSeek
+                        || t.first == ""
+                        ) {
+                    special_fim_suf_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_MID token: "<|fim_middle|>", "", "", etc.
+            if (special_fim_mid_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_middle|>" // Qwen
+                        || t.first == ""
+                        || t.first == "<|fim▁end|>"  // DeepSeek
+                        || t.first == ""
+                        ) {
+                    special_fim_mid_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_PAD token: "<|fim_pad|>", "", "", etc.
+            if (special_fim_pad_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_pad|>" // Qwen
+                        || t.first == ""
+                        || t.first == ""
+                        ) {
+                    special_fim_pad_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_REP token: "<|fim_repo|>", "", "", etc.
+            if (special_fim_rep_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_repo|>"  // Qwen
+                        || t.first == "<|repo_name|>"
+                        || t.first == ""
+                        || t.first == ""
+                        ) {
+                    special_fim_rep_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_SEP token: "<|file_sep|>"
+            if (special_fim_sep_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|file_sep|>" // Qwen
+                        ) {
+                    special_fim_sep_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+        }
+
+        // maintain a list of tokens that cause end-of-generation
+        // this is currently determined based on the token text, which is obviously not ideal
+        // ref: https://github.com/ggerganov/llama.cpp/issues/9606
+        special_eog_ids.clear();
+
+        if (special_fim_pad_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_fim_pad_id) == 0) {
+            special_eog_ids.insert(special_fim_pad_id);
+        }
+
+        if (special_fim_rep_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_fim_rep_id) == 0) {
+            special_eog_ids.insert(special_fim_rep_id);
+        }
+
+        if (special_fim_sep_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_fim_sep_id) == 0) {
+            special_eog_ids.insert(special_fim_sep_id);
+        }
+
+        for (const auto & t : token_to_id) {
+            if (false
+                    || t.first == "<|eot_id|>"
+                    || t.first == "<|im_end|>"
+                    || t.first == "<|end|>"
+                    || t.first == ""
+                    || t.first == "<|endoftext|>"
+                    || t.first == "<|eom_id|>"
+                    || t.first == ""
+               ) {
+                special_eog_ids.insert(t.second);
+                if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                    LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                            __func__, t.second, t.first.c_str());
+                    id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                }
+            } else {
+                // token is control, but not marked as EOG -> print a debug log
+                if (id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL && special_eog_ids.count(t.second) == 0) {
+                    LLAMA_LOG_DEBUG("%s: control token: %6d '%s' is not marked as EOG\n",
+                            __func__, t.second, t.first.c_str());
+                }
+            }
+        }
+
+        // sanity checks
+        if (special_eos_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_eos_id) == 0) {
+            special_eog_ids.insert(special_eos_id);
+            LLAMA_LOG_WARN("%s: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
+
+        if (special_eot_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_eot_id) == 0) {
+            special_eog_ids.insert(special_eot_id);
+            LLAMA_LOG_WARN("%s: special_eot_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
+
+        if (special_eom_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_eom_id) == 0) {
+            special_eog_ids.insert(special_eom_id);
+            LLAMA_LOG_WARN("%s: special_eom_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
+    }
+
+    // build special tokens cache
+    {
+        for (llama_token id = 0; id < (llama_token) n_tokens; ++id) {
+            if (id_to_token[id].attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)) {
+                cache_special_tokens.push_back(id);
+            }
+        }
+
+        std::sort(cache_special_tokens.begin(), cache_special_tokens.end(),
+            [&] (const llama_token a, const llama_token b) {
+                return id_to_token[a].text.size() > id_to_token[b].text.size();
+            }
+        );
+
+        LLAMA_LOG_INFO("%s: special tokens cache size = %u\n", __func__, (uint32_t) cache_special_tokens.size());
+    }
+
+    // build token to piece cache
+    {
+        size_t size_cache = 0;
+
+        std::vector cache(n_tokens);
+
+        for (uint32_t id = 0; id < n_tokens; ++id) {
+            cache[id] = token_to_piece_for_cache(id, true);
+
+            size_cache += cache[id].size();
+        }
+
+        std::swap(cache_token_to_piece, cache);
+
+        LLAMA_LOG_INFO("%s: token to piece cache size = %.4f MB\n", __func__, size_cache / 1024.0 / 1024.0);
+    }
+
+    // Handle per token attributes
+    //NOTE: Each model customizes per token attributes.
+    //NOTE: Per token attributes are missing from the GGUF file.
+    //TODO: Extract attributes from GGUF file.
+    {
+        auto _contains_any = [] (const std::string & str, const std::vector & substrs) -> bool {
+            for (const auto & substr : substrs) {
+                if (str.find(substr) < std::string::npos) {
+                    return true;
+                }
+            }
+            return false;
+        };
+
+        auto _set_tokenid_attr = [&] (const llama_token id, llama_token_attr attr, bool value) {
+            uint32_t current = id_to_token.at(id).attr;
+            current = value ? (current | attr) : (current & ~attr);
+            id_to_token[id].attr = (llama_token_attr) current;
+        };
+
+        auto _set_token_attr = [&] (const std::string & token, llama_token_attr attr, bool value) {
+            _set_tokenid_attr(token_to_id.at(token), attr, value);
+        };
+
+        std::string model_name;
+        std::string tokenizer_pre;
+
+        ml.get_key(LLM_KV_GENERAL_NAME,  model_name,    false);
+        ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
+
+        // model name to lowercase
+        std::transform(model_name.begin(), model_name.end(), model_name.begin(),
+            [] (const std::string::value_type x) {
+                return std::tolower(x);
+            }
+        );
+
+        // set attributes by model/tokenizer name
+        if (_contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})) {
+            _set_token_attr("", LLAMA_TOKEN_ATTR_LSTRIP, true);
+        } else if (_contains_any(model_name, {"phi-3", "phi3"})) {
+            for (auto id : cache_special_tokens) {
+                _set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true);
+            }
+            for (const auto * token : {""}) {
+                _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, true);
+            }
+            for (const auto * token : {"", "", "<|endoftext|>"}) {
+                _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false);
+            }
+        }
+    }
+}
+
+enum llama_vocab_type llama_vocab::impl::get_type() const {
+    return type;
+}
+
+std::string llama_vocab::impl::type_name() const{
+    switch (type) {
+        case LLAMA_VOCAB_TYPE_NONE: return "no vocab";
+        case LLAMA_VOCAB_TYPE_SPM:  return "SPM";
+        case LLAMA_VOCAB_TYPE_BPE:  return "BPE";
+        case LLAMA_VOCAB_TYPE_WPM:  return "WPM";
+        case LLAMA_VOCAB_TYPE_UGM:  return "UGM";
+        case LLAMA_VOCAB_TYPE_RWKV: return "RWKV";
+        default:                    return "unknown";
+    }
+}
+
+bool llama_vocab::impl::is_normal(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL;
+}
+
+bool llama_vocab::impl::is_unknown(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNKNOWN;
+}
+
+bool llama_vocab::impl::is_control(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_CONTROL;
+}
+
+bool llama_vocab::impl::is_byte(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_BYTE;
+}
+
+bool llama_vocab::impl::is_user_defined(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_USER_DEFINED;
+}
+
+bool llama_vocab::impl::is_unused(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNUSED;
+}
+
+bool llama_vocab::impl::is_eog(llama_token id) const {
+    return id != LLAMA_TOKEN_NULL && special_eog_ids.count(id) > 0;
+}
+
+uint8_t llama_vocab::impl::token_to_byte(llama_token id) const {
+    GGML_ASSERT(get_type() != LLAMA_VOCAB_TYPE_NONE);
+    GGML_ASSERT(is_byte(id));
+    const auto & token_data = id_to_token.at(id);
+    switch (get_type()) {
+        case LLAMA_VOCAB_TYPE_SPM:
+        case LLAMA_VOCAB_TYPE_UGM: {
+            auto buf = token_data.text.substr(3, 2);
+            return strtol(buf.c_str(), NULL, 16);
+        }
+        case LLAMA_VOCAB_TYPE_BPE: {
+            GGML_ABORT("fatal error");
+        }
+        case LLAMA_VOCAB_TYPE_WPM: {
+            GGML_ABORT("fatal error");
+        }
+        default:
+            GGML_ABORT("fatal error");
+    }
+}
+
+llama_token_attr llama_vocab::impl::token_get_attr(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token.at(id).attr;
+}
+
+void llama_vocab::impl::init_tokenizer(enum llama_vocab_type type) {
+    LLAMA_LOG_DEBUG("%s: initializing tokenizer for type %d\n", __func__, type);
+
+    switch (type) {
+        case LLAMA_VOCAB_TYPE_SPM:
+            tokenizer = std::make_unique(vocab);
+            break;
+        case LLAMA_VOCAB_TYPE_BPE:
+            tokenizer = std::make_unique(vocab);
+            break;
+        case LLAMA_VOCAB_TYPE_WPM:
+            tokenizer = std::make_unique(vocab);
+            break;
+        case LLAMA_VOCAB_TYPE_UGM:
+            tokenizer = std::make_unique(vocab, precompiled_charsmap);
+            break;
+        case LLAMA_VOCAB_TYPE_RWKV:
+            tokenizer = std::make_unique(vocab);
+            break;
+        default:
+            GGML_ABORT("unsupported vocab type");
+    }
+}
+
+//
+// (de-) tokenize
+//
+
 // #define PRETOKENIZERDEBUG
 
-static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list & buffer, bool parse_special) {
+void llama_vocab::impl::tokenizer_st_partition(std::forward_list & buffer, bool parse_special) const {
     // for each special token
-    for (const llama_vocab::id special_id : vocab.cache_special_tokens) {
-        const auto & data = vocab.id_to_token[special_id];
-        const auto & special_token = data.text;
+    for (const llama_token special_id : cache_special_tokens) {
+        const auto & data = vocab.get_token_data(special_id);
+        const auto & text = data.text;
 
         if (!parse_special && (data.attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_UNKNOWN))) {
             // Ignore control and unknown tokens when parse_special == false
@@ -1339,13 +2169,13 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
                     // find the first occurrence of a given special token in this fragment
                     //  passing offset argument only limit the "search area" but match coordinates
                     //  are still relative to the source full raw_text
-                    auto match = raw_text.find(special_token, raw_text_base_offset);
+                    auto match = raw_text.find(text, raw_text_base_offset);
 
                     // no occurrences found, stop processing this fragment for a given special token
                     if (match == std::string::npos) break;
 
                     // check if match is within bounds of offset <-> length
-                    if (match + special_token.length() > raw_text_base_offset + raw_text_base_length) break;
+                    if (match + text.length() > raw_text_base_offset + raw_text_base_length) break;
 
 #ifdef PRETOKENIZERDEBUG
                     LLAMA_LOG_WARN("FF: (%ld %ld %ld) '%s'\n", raw_text->length(), raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
@@ -1380,9 +2210,9 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
                     it++;
 
                     // right
-                    if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) {
-                        int64_t right_reminder_offset = match + special_token.length();
-                        int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length());
+                    if (match + text.length() < raw_text_base_offset + raw_text_base_length) {
+                        int64_t right_reminder_offset = match + text.length();
+                        int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + text.length());
 
                         if (data.attr & LLAMA_TOKEN_ATTR_RSTRIP) {
                             while (right_reminder_length > 0 && isspace(raw_text[right_reminder_offset])) {
@@ -1403,7 +2233,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
                         if (source == 0) {
                             buffer.erase_after(buffer.before_begin());
                         } else {
-                            buffer.erase_after(std::next(buffer.begin(), (source-1)));
+                            buffer.erase_after(std::next(buffer.begin(), (source - 1)));
                         }
 
                         // repeat for the right side
@@ -1417,7 +2247,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
                         if (source == 0) {
                             buffer.erase_after(buffer.before_begin());
                         } else {
-                            buffer.erase_after(std::next(buffer.begin(), (source-1)));
+                            buffer.erase_after(std::next(buffer.begin(), (source - 1)));
                         }
                         break;
                     }
@@ -1428,322 +2258,29 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
     }
 }
 
-std::vector llama_tokenize_internal(
-        const llama_vocab & vocab,
-        std::string raw_text,
-        bool add_special,
-        bool parse_special) {
-    GGML_ASSERT(vocab.tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
-
-    std::vector output;
-    std::forward_list fragment_buffer;
-
-    if (!raw_text.empty()) {
-        fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
-        tokenizer_st_partition(vocab, fragment_buffer, parse_special);
+// NOTE: avoid ever using this except for building the token_to_piece caches
+std::string llama_vocab::impl::token_to_piece_for_cache(llama_token token, bool special) const {
+    std::string piece;
+    piece.resize(piece.capacity());  // using string internal cache
+    const int n_chars = vocab.token_to_piece(token, &piece[0], piece.size(), 0, special);
+    if (n_chars < 0) {
+        piece.resize(-n_chars);
+        int check = vocab.token_to_piece(token, &piece[0], piece.size(), 0, special);
+        GGML_ASSERT(check == -n_chars);
+    }
+    else {
+        piece.resize(n_chars);
     }
 
-    switch (vocab.type) {
-        case LLAMA_VOCAB_TYPE_SPM:
-            {
-                // OG tokenizer behavior:
-                //
-                // tokenizer.encode('', add_special_tokens=True)  returns [1]
-                // tokenizer.encode('', add_special_tokens=False) returns []
-
-                bool is_prev_special = true;  // prefix with space if first token
-
-                if (add_special && vocab.tokenizer_add_bos) {
-                    GGML_ASSERT(vocab.special_bos_id != -1);
-                    output.push_back(vocab.special_bos_id);
-                    is_prev_special = true;
-                }
-
-                for (const auto & fragment : fragment_buffer) {
-                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
-
-                        // prefix with space if previous is special
-                        if (vocab.tokenizer_add_space_prefix && is_prev_special) {
-                            raw_text = " " + raw_text;
-                        }
-
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
-#endif
-                        llama_escape_whitespace(raw_text);
-                        llm_tokenizer_spm_session session(vocab);
-                        session.tokenize(raw_text, output);
-                        is_prev_special = false;
-                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        output.push_back(fragment.token);
-                        is_prev_special = true;
-                    }
-                }
-
-                if (add_special && vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
-                    LLAMA_LOG_WARN(
-                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
-                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
-                        "Are you sure this is what you want?\n", __FUNCTION__);
-                }
-
-                if (add_special && vocab.tokenizer_add_eos) {
-                    GGML_ASSERT(vocab.special_eos_id != -1);
-                    output.push_back(vocab.special_eos_id);
-                }
-            } break;
-        case LLAMA_VOCAB_TYPE_BPE:
-            {
-                llm_tokenizer_bpe_session session(vocab);
-                // it calls some other methods that are not exist in llm_tokenizer,
-                // here just cast it to bpe tokenizer object
-                if (add_special) {
-                    session.append_bos(output);
-                }
-                for (const auto & fragment : fragment_buffer) {
-                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
-
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
-#endif
-                        session.tokenize(raw_text, output);
-                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        session.append(fragment.token, output);
-                    }
-                }
-
-                if (add_special) {
-                    session.append_eos(output);
-                    session.check_double_bos_eos(output);
-                }
-            } break;
-        case LLAMA_VOCAB_TYPE_WPM:
-            {
-                if (add_special) {
-                    GGML_ASSERT(vocab.special_cls_id != -1);
-                    output.push_back(vocab.special_cls_id);
-                }
-
-                llm_tokenizer_wpm_session session(vocab);
-
-                for (const auto & fragment : fragment_buffer) {
-                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
-
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
-#endif
-                        session.tokenize(raw_text, output);
-                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        output.push_back(fragment.token);
-                    }
-                }
-
-                if (add_special) {
-                    GGML_ASSERT(vocab.special_sep_id != -1);
-                    output.push_back(vocab.special_sep_id);
-                }
-            } break;
-        case LLAMA_VOCAB_TYPE_UGM:
-            {
-                if (add_special && vocab.tokenizer_add_bos) {
-                    GGML_ASSERT(vocab.special_bos_id != -1);
-                    output.push_back(vocab.special_bos_id);
-                }
-                llm_tokenizer_ugm_session session(vocab);
-
-                for (const auto & fragment : fragment_buffer) {
-                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
-#endif
-                        session.tokenize(raw_text, output);
-                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        output.push_back(fragment.token);
-                    }
-                }
-
-                if (add_special && vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
-                    LLAMA_LOG_WARN(
-                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
-                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
-                        "Are you sure this is what you want?\n", __FUNCTION__);
-                }
-
-                if (add_special && vocab.tokenizer_add_eos) {
-                    GGML_ASSERT(vocab.special_eos_id != -1);
-                    output.push_back(vocab.special_eos_id);
-                }
-            } break;
-        case LLAMA_VOCAB_TYPE_RWKV:
-            {
-                llm_tokenizer_rwkv_session session(vocab);
-                for (const auto & fragment : fragment_buffer) {
-                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
-
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
-#endif
-
-                        session.tokenize(raw_text, output);
-                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        output.push_back(fragment.token);
-                    }
-                }
-            } break;
-        case LLAMA_VOCAB_TYPE_NONE:
-            GGML_ABORT("fatal error");
-    }
-
-    return output;
+    return piece;
 }
 
-llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch) {
-    GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE);
-    static const char * hex = "0123456789ABCDEF";
-    switch (llama_vocab_get_type(vocab)) {
-        case LLAMA_VOCAB_TYPE_SPM:
-        case LLAMA_VOCAB_TYPE_UGM: {
-            const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 };
-            auto token = vocab.token_to_id.find(buf);
-            if (token != vocab.token_to_id.end()) {
-                return (*token).second;
-            }
-            // Try to fall back to just the byte as a string
-            const char buf2[2] = { (char)ch, 0 };
-            return vocab.token_to_id.at(buf2);
-        }
-        case LLAMA_VOCAB_TYPE_WPM:
-        case LLAMA_VOCAB_TYPE_BPE: {
-            return vocab.token_to_id.at(unicode_byte_to_utf8(ch));
-        }
-        default:
-            GGML_ABORT("fatal error");
-    }
+static void llama_escape_whitespace(std::string & text) {
+    replace_all(text, " ", "\xe2\x96\x81");
 }
 
-const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[token].text.c_str();
-}
-
-float llama_token_get_score_impl(const struct llama_vocab & vocab, llama_token token) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[token].score;
-}
-
-llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, llama_token token) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[token].attr;
-}
-
-bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token) {
-    return token != -1 && vocab.special_eog_ids.count(token) > 0;
-}
-
-bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token) {
-    return llama_is_control_token(vocab, token);
-}
-
-llama_token llama_token_bos_impl(const struct llama_vocab & vocab) {
-    return vocab.type != LLAMA_VOCAB_TYPE_WPM ? vocab.special_bos_id : vocab.special_cls_id;
-}
-
-llama_token llama_token_eos_impl(const struct llama_vocab & vocab) {
-    return vocab.special_eos_id;
-}
-
-llama_token llama_token_eot_impl(const struct llama_vocab & vocab) {
-    return vocab.special_eot_id;
-}
-
-llama_token llama_token_eom_impl(const struct llama_vocab & vocab) {
-    return vocab.special_eom_id;
-}
-
-llama_token llama_token_cls_impl(const struct llama_vocab & vocab) {
-    return vocab.special_cls_id;
-}
-
-llama_token llama_token_sep_impl(const struct llama_vocab & vocab) {
-    return vocab.special_sep_id;
-}
-
-llama_token llama_token_nl_impl(const struct llama_vocab & vocab) {
-    return vocab.linefeed_id;
-}
-
-llama_token llama_token_pad_impl(const struct llama_vocab & vocab) {
-    return vocab.special_pad_id;
-}
-
-bool llama_add_bos_token_impl(const struct llama_vocab & vocab) {
-    return vocab.tokenizer_add_bos;
-}
-
-bool llama_add_eos_token_impl(const struct llama_vocab & vocab) {
-    return vocab.tokenizer_add_eos;
-}
-
-llama_token llama_token_prefix_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_pre_id;
-}
-
-llama_token llama_token_middle_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_mid_id;
-}
-
-llama_token llama_token_suffix_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_suf_id;
-}
-
-llama_token llama_token_fim_pre_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_pre_id;
-}
-
-llama_token llama_token_fim_suf_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_suf_id;
-}
-
-llama_token llama_token_fim_mid_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_mid_id;
-}
-
-llama_token llama_token_fim_pad_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_pad_id;
-}
-
-llama_token llama_token_fim_rep_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_rep_id;
-}
-
-llama_token llama_token_fim_sep_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_sep_id;
-}
-
-int32_t llama_tokenize_impl(
-        const struct llama_vocab & vocab,
-                      const char * text,
-                         int32_t   text_len,
-                     llama_token * tokens,
-                         int32_t   n_tokens_max,
-                            bool   add_special,
-                            bool   parse_special) {
-    auto res = llama_tokenize_internal(vocab, std::string(text, text_len), add_special, parse_special);
-    if (n_tokens_max < (int) res.size()) {
-        // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
-        return -((int) res.size());
-    }
-
-    for (size_t i = 0; i < res.size(); i++) {
-        tokens[i] = res[i];
-    }
-
-    return res.size();
+static void llama_unescape_whitespace(std::string & word) {
+    replace_all(word, "\xe2\x96\x81", " ");
 }
 
 static std::string llama_decode_text(const std::string & text) {
@@ -1766,11 +2303,185 @@ static std::string llama_decode_text(const std::string & text) {
     return decoded_text;
 }
 
-// does not write null-terminator to buf
-int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) {
+std::vector llama_vocab::impl::tokenize(
+        const std::string & raw_text,
+        bool add_special,
+        bool parse_special) const {
+    GGML_ASSERT(tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
+
+    std::vector output;
+    std::forward_list fragment_buffer;
+
+    if (!raw_text.empty()) {
+        fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
+        tokenizer_st_partition(fragment_buffer, parse_special);
+    }
+
+    switch (get_type()) {
+        case LLAMA_VOCAB_TYPE_SPM:
+            {
+                // OG tokenizer behavior:
+                //
+                // tokenizer.encode('', add_special_tokens=True)  returns [1]
+                // tokenizer.encode('', add_special_tokens=False) returns []
+
+                bool is_prev_special = true;  // prefix with space if first token
+
+                if (add_special && add_bos) {
+                    GGML_ASSERT(special_bos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_bos_id);
+                    is_prev_special = true;
+                }
+
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text;
+
+                        // prefix with space if previous is special
+                        if (add_space_prefix && is_prev_special) {
+                            text = ' ';
+                        }
+
+                        text += fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+                        llama_escape_whitespace(text);
+                        llm_tokenizer_spm_session session(vocab);
+                        session.tokenize(text, output);
+                        is_prev_special = false;
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                        is_prev_special = true;
+                    }
+                }
+
+                if (add_special && add_bos && output.size() >= 2 && output[1] == special_bos_id) {
+                    LLAMA_LOG_WARN(
+                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
+                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
+                        "Are you sure this is what you want?\n", __FUNCTION__);
+                }
+
+                if (add_special && add_eos) {
+                    GGML_ASSERT(special_eos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_eos_id);
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_BPE:
+            {
+                llm_tokenizer_bpe_session session(vocab, *static_cast(tokenizer.get()));
+                // it calls some other methods that are not exist in llm_tokenizer,
+                // here just cast it to bpe tokenizer object
+                if (add_special) {
+                    session.append_bos(output);
+                }
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+                        session.tokenize(text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        session.append(fragment.token, output);
+                    }
+                }
+
+                if (add_special) {
+                    session.append_eos(output);
+                    session.check_double_bos_eos(output);
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_WPM:
+            {
+                if (add_special) {
+                    GGML_ASSERT(special_bos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_bos_id);
+                }
+
+                llm_tokenizer_wpm_session session(vocab);
+
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+                        session.tokenize(text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+
+                if (add_special) {
+                    GGML_ASSERT(special_sep_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_sep_id);
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_UGM:
+            {
+                if (add_special && add_bos) {
+                    GGML_ASSERT(special_bos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_bos_id);
+                }
+                llm_tokenizer_ugm_session session(vocab, *static_cast(tokenizer.get()));
+
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+                        session.tokenize(text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+
+                if (add_special && add_bos && output.size() >= 2 && output[1] == special_bos_id) {
+                    LLAMA_LOG_WARN(
+                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
+                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
+                        "Are you sure this is what you want?\n", __FUNCTION__);
+                }
+
+                if (add_special && add_eos) {
+                    GGML_ASSERT(special_eos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_eos_id);
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_RWKV:
+            {
+                llm_tokenizer_rwkv_session session(vocab, *static_cast(tokenizer.get()));
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+
+                        session.tokenize(text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_NONE:
+            GGML_ABORT("fatal error");
+    }
+
+    return output;
+}
+
+int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) const {
     // ref: https://github.com/ggerganov/llama.cpp/pull/7587#discussion_r1620983843
     static const int attr_special = LLAMA_TOKEN_ATTR_UNKNOWN | LLAMA_TOKEN_ATTR_CONTROL;
-    const llama_token_attr attr = llama_token_get_attr_impl(vocab, token);
+    const llama_token_attr attr = token_get_attr(token);
     if (!special && (attr & attr_special)) {
         return 0;
     }
@@ -1791,7 +2502,7 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
 
     // if we have a cache - use it
     {
-        const auto & cache = vocab.cache_token_to_piece;
+        const auto & cache = cache_token_to_piece;
 
         if (!cache.empty()) {
             const auto & result = cache.at(token);
@@ -1799,9 +2510,9 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
         }
     }
 
-    if (0 <= token && token < (int32_t) vocab.id_to_token.size()) {
-        const std::string & token_text = vocab.id_to_token[token].text;
-        switch (llama_vocab_get_type(vocab)) {
+    if (0 <= token && token < (int32_t) id_to_token.size()) {
+        const std::string & token_text = id_to_token[token].text;
+        switch (get_type()) {
             case LLAMA_VOCAB_TYPE_WPM:
             case LLAMA_VOCAB_TYPE_SPM:
             case LLAMA_VOCAB_TYPE_UGM: {
@@ -1816,7 +2527,7 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
                     return _try_copy(result.data(), result.size());
                 }
                 if (attr & LLAMA_TOKEN_ATTR_BYTE) {
-                    char byte = (char) llama_token_to_byte(vocab, token);
+                    char byte = (char) token_to_byte(token);
                     return _try_copy((char*) &byte, 1);
                 }
                 break;
@@ -1852,43 +2563,46 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
     return 0;
 }
 
-int32_t llama_detokenize_impl(
-        const struct llama_vocab & vocab,
+const std::string & llama_vocab::impl::token_to_piece(llama_token token) const {
+    return cache_token_to_piece.at(token);
+}
+
+int32_t llama_vocab::impl::detokenize(
                const llama_token * tokens,
                          int32_t   n_tokens,
                             char * text,
                          int32_t   text_len_max,
                             bool   remove_special,
-                            bool   unparse_special) {
-    if (vocab.type == LLAMA_VOCAB_TYPE_NONE) {
+                            bool   unparse_special) const {
+    if (type == LLAMA_VOCAB_TYPE_NONE) {
         return 0;
     }
 
-    GGML_ASSERT(vocab.tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
+    GGML_ASSERT(tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
 
     int32_t avail = text_len_max;
     int32_t total = 0;
 
     // remove the leading space
-    bool remove_space = vocab.tokenizer_add_space_prefix;
+    bool remove_space = add_space_prefix;
 
-    if (remove_special && vocab.tokenizer_add_bos) {
-        if (n_tokens > 0 && tokens[0] == vocab.special_bos_id) {
+    if (remove_special && add_bos) {
+        if (n_tokens > 0 && tokens[0] == special_bos_id) {
             remove_space = false;
             n_tokens--;
             tokens++;
         }
     }
 
-    if (remove_special && vocab.tokenizer_add_eos) {
-        if (n_tokens > 0 && tokens[n_tokens-1] == vocab.special_eos_id) {
+    if (remove_special && add_eos) {
+        if (n_tokens > 0 && tokens[n_tokens - 1] == special_eos_id) {
             n_tokens--;
         }
     }
 
     for (int32_t i = 0; i < n_tokens; ++i) {
         GGML_ASSERT(avail >= 0);
-        int32_t n_chars = llama_token_to_piece_impl(vocab, tokens[i], text, avail, remove_space, unparse_special);
+        int32_t n_chars = token_to_piece(tokens[i], text, avail, remove_space, unparse_special);
         remove_space = false;
         if (n_chars < 0) {
             avail = 0;
@@ -1904,7 +2618,7 @@ int32_t llama_detokenize_impl(
         return -total;
     }
 
-    if (vocab.tokenizer_clean_spaces) {
+    if (clean_spaces) {
         text -= total;  // restart text
 
         // first pass: characters ?!.,  //TODO: where do these characters come from?
@@ -1965,13 +2679,321 @@ int32_t llama_detokenize_impl(
     return total <= text_len_max ? total : -total;
 }
 
-std::string llama_detokenize(const struct llama_vocab & vocab, const std::vector & tokens, bool special) {
+void llama_vocab::impl::print_info() const {
+    LLAMA_LOG_INFO("%s: vocab type       = %s\n",     __func__, type_name().c_str());
+    LLAMA_LOG_INFO("%s: n_vocab          = %u\n",     __func__, vocab.n_tokens());
+    LLAMA_LOG_INFO("%s: n_merges         = %u\n",     __func__, (uint32_t) bpe_ranks.size());
+
+    // special tokens
+    if (special_bos_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: BOS token        = %d '%s'\n", __func__, special_bos_id,     id_to_token[special_bos_id].text.c_str() );  }
+    if (special_eos_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOS token        = %d '%s'\n", __func__, special_eos_id,     id_to_token[special_eos_id].text.c_str() );  }
+    if (special_eot_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOT token        = %d '%s'\n", __func__, special_eot_id,     id_to_token[special_eot_id].text.c_str() );  }
+    if (special_eom_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOM token        = %d '%s'\n", __func__, special_eom_id,     id_to_token[special_eom_id].text.c_str() );  }
+    if (special_unk_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: UNK token        = %d '%s'\n", __func__, special_unk_id,     id_to_token[special_unk_id].text.c_str() );  }
+    if (special_sep_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: SEP token        = %d '%s'\n", __func__, special_sep_id,     id_to_token[special_sep_id].text.c_str() );  }
+    if (special_pad_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: PAD token        = %d '%s'\n", __func__, special_pad_id,     id_to_token[special_pad_id].text.c_str() );  }
+    if (special_mask_id != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: MASK token       = %d '%s'\n", __func__, special_mask_id,    id_to_token[special_mask_id].text.c_str() ); }
+
+    if (linefeed_id != LLAMA_TOKEN_NULL)        { LLAMA_LOG_INFO( "%s: LF token         = %d '%s'\n", __func__, linefeed_id,        id_to_token[linefeed_id].text.c_str() ); }
+
+    if (special_fim_pre_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PRE token    = %d '%s'\n", __func__, special_fim_pre_id, id_to_token[special_fim_pre_id].text.c_str() ); }
+    if (special_fim_suf_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SUF token    = %d '%s'\n", __func__, special_fim_suf_id, id_to_token[special_fim_suf_id].text.c_str() ); }
+    if (special_fim_mid_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM MID token    = %d '%s'\n", __func__, special_fim_mid_id, id_to_token[special_fim_mid_id].text.c_str() ); }
+    if (special_fim_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PAD token    = %d '%s'\n", __func__, special_fim_pad_id, id_to_token[special_fim_pad_id].text.c_str() ); }
+    if (special_fim_rep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM REP token    = %d '%s'\n", __func__, special_fim_rep_id, id_to_token[special_fim_rep_id].text.c_str() ); }
+    if (special_fim_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SEP token    = %d '%s'\n", __func__, special_fim_sep_id, id_to_token[special_fim_sep_id].text.c_str() ); }
+
+    for (const auto & id : special_eog_ids) {
+        LLAMA_LOG_INFO( "%s: EOG token        = %d '%s'\n", __func__, id, id_to_token[id].text.c_str() );
+    }
+
+    LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, max_token_len);
+}
+
+llama_vocab::llama_vocab() : pimpl(new impl(*this)) {
+}
+
+llama_vocab::~llama_vocab() {
+}
+
+void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) {
+    pimpl->load(ml, kv);
+}
+
+enum llama_vocab_type llama_vocab::get_type() const {
+    return pimpl->type;
+}
+
+enum llama_vocab_pre_type llama_vocab::get_pre_type() const {
+    return pimpl->pre_type;
+}
+
+uint32_t llama_vocab::n_tokens() const {
+    return (uint32_t) pimpl->id_to_token.size();
+}
+
+uint32_t llama_vocab::n_token_types() const {
+    return (uint32_t) pimpl->n_token_types;
+}
+
+std::string llama_vocab::type_name() const{
+    return pimpl->type_name();
+}
+
+bool llama_vocab::is_normal(llama_token id) const {
+    return pimpl->is_normal(id);
+}
+
+bool llama_vocab::is_unknown(llama_token id) const {
+    return pimpl->is_unknown(id);
+}
+
+bool llama_vocab::is_control(llama_token id) const {
+    return pimpl->is_control(id);
+}
+
+bool llama_vocab::is_byte(llama_token id) const {
+    return pimpl->is_byte(id);
+}
+
+bool llama_vocab::is_user_defined(llama_token id) const {
+    return pimpl->is_user_defined(id);
+}
+
+bool llama_vocab::is_unused(llama_token id) const {
+    return pimpl->is_unused(id);
+}
+
+bool llama_vocab::is_eog(llama_token id) const {
+    return pimpl->is_eog(id);
+}
+
+uint8_t llama_vocab::token_to_byte(llama_token id) const {
+    return pimpl->token_to_byte(id);
+}
+
+llama_token llama_vocab::byte_to_token(uint8_t ch) const {
+    GGML_ASSERT(get_type() != LLAMA_VOCAB_TYPE_NONE);
+    static const char * hex = "0123456789ABCDEF";
+    switch (get_type()) {
+        case LLAMA_VOCAB_TYPE_SPM:
+        case LLAMA_VOCAB_TYPE_UGM: {
+            const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 };
+            auto token = pimpl->token_to_id.find(buf);
+            if (token != pimpl->token_to_id.end()) {
+                return (*token).second;
+            }
+            // Try to fall back to just the byte as a string
+            const char buf2[2] = { (char)ch, 0 };
+            return pimpl->token_to_id.at(buf2);
+        }
+        case LLAMA_VOCAB_TYPE_WPM:
+        case LLAMA_VOCAB_TYPE_BPE: {
+            return pimpl->token_to_id.at(unicode_byte_to_utf8(ch));
+        }
+        default:
+            GGML_ABORT("fatal error");
+    }
+}
+
+llama_token llama_vocab::text_to_token(const std::string & text) const {
+    GGML_ASSERT(pimpl->type != LLAMA_VOCAB_TYPE_NONE);
+    auto it = pimpl->token_to_id.find(text);
+    if (it != pimpl->token_to_id.end()) {
+        return (*it).second;
+    }
+    return LLAMA_TOKEN_NULL;
+}
+
+const llama_vocab::token_data & llama_vocab::get_token_data(llama_token id) const {
+    GGML_ASSERT(pimpl->type != LLAMA_VOCAB_TYPE_NONE);
+    return pimpl->id_to_token.at(id);
+}
+
+const char * llama_vocab::token_get_text(llama_token id) const {
+    GGML_ASSERT(pimpl->type != LLAMA_VOCAB_TYPE_NONE);
+    return pimpl->id_to_token.at(id).text.c_str();
+}
+
+float llama_vocab::token_get_score(llama_token id) const {
+    GGML_ASSERT(pimpl->type != LLAMA_VOCAB_TYPE_NONE);
+    return pimpl->id_to_token.at(id).score;
+}
+
+llama_token_attr llama_vocab::token_get_attr(llama_token id) const {
+    return pimpl->token_get_attr(id);
+}
+
+llama_token llama_vocab::token_bos() const {
+    return pimpl->special_bos_id;
+}
+
+llama_token llama_vocab::token_eos() const {
+    return pimpl->special_eos_id;
+}
+
+llama_token llama_vocab::token_eot() const {
+    return pimpl->special_eot_id;
+}
+
+llama_token llama_vocab::token_eom() const {
+    return pimpl->special_eom_id;
+}
+
+llama_token llama_vocab::token_unk() const {
+    return pimpl->special_unk_id;
+}
+
+llama_token llama_vocab::token_sep() const {
+    return pimpl->special_sep_id;
+}
+
+llama_token llama_vocab::token_nl() const {
+    return pimpl->linefeed_id;
+}
+
+llama_token llama_vocab::token_pad() const {
+    return pimpl->special_pad_id;
+}
+
+llama_token llama_vocab::token_prefix() const {
+    return pimpl->special_fim_pre_id;
+}
+
+llama_token llama_vocab::token_middle() const {
+    return pimpl->special_fim_mid_id;
+}
+
+llama_token llama_vocab::token_suffix() const {
+    return pimpl->special_fim_suf_id;
+}
+
+llama_token llama_vocab::token_fim_pre() const {
+    return pimpl->special_fim_pre_id;
+}
+
+llama_token llama_vocab::token_fim_suf() const {
+    return pimpl->special_fim_suf_id;
+}
+
+llama_token llama_vocab::token_fim_mid() const {
+    return pimpl->special_fim_mid_id;
+}
+
+llama_token llama_vocab::token_fim_pad() const {
+    return pimpl->special_fim_pad_id;
+}
+
+llama_token llama_vocab::token_fim_rep() const {
+    return pimpl->special_fim_rep_id;
+}
+
+llama_token llama_vocab::token_fim_sep() const {
+    return pimpl->special_fim_sep_id;
+}
+
+bool llama_vocab::get_add_space_prefix() const {
+    return pimpl->add_space_prefix;
+}
+
+bool llama_vocab::get_add_bos() const {
+    return pimpl->add_bos;
+}
+
+bool llama_vocab::get_add_eos() const {
+    return pimpl->add_eos;
+}
+
+bool llama_vocab::get_ignore_merges() const {
+    return pimpl->ignore_merges;
+}
+
+bool llama_vocab::get_clean_spaces() const {
+    return pimpl->clean_spaces;
+}
+
+bool llama_vocab::get_remove_extra_whitespaces() const {
+    return pimpl->remove_extra_whitespaces;
+}
+
+bool llama_vocab::get_escape_whitespaces() const {
+    return pimpl->escape_whitespaces;
+}
+
+bool llama_vocab::get_treat_whitespace_as_suffix() const {
+    return pimpl->treat_whitespace_as_suffix;
+}
+
+int llama_vocab::max_token_len() const {
+    return pimpl->max_token_len;
+}
+
+int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string & token_right) const {
+    GGML_ASSERT(token_left.find(' ')   == std::string::npos);
+    GGML_ASSERT(token_left.find('\n')  == std::string::npos);
+    GGML_ASSERT(token_right.find(' ')  == std::string::npos);
+    GGML_ASSERT(token_right.find('\n') == std::string::npos);
+
+    auto it = pimpl->bpe_ranks.find(std::make_pair(token_left, token_right));
+    if (it == pimpl->bpe_ranks.end()) {
+        return -1;
+    }
+
+    return it->second;
+}
+
+int32_t llama_vocab::tokenize(
+                  const char * text,
+                     int32_t   text_len,
+                 llama_token * tokens,
+                     int32_t   n_tokens_max,
+                        bool   add_special,
+                        bool   parse_special) const {
+    auto res = tokenize(std::string(text, text_len), add_special, parse_special);
+    if (n_tokens_max < (int) res.size()) {
+        // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
+        return -((int) res.size());
+    }
+
+    for (size_t i = 0; i < res.size(); i++) {
+        tokens[i] = res[i];
+    }
+
+    return res.size();
+}
+
+std::vector llama_vocab::tokenize(
+        const std::string & raw_text,
+        bool add_special,
+        bool parse_special) const {
+    return pimpl->tokenize(raw_text, add_special, parse_special);
+}
+
+const std::string & llama_vocab::token_to_piece(llama_token token) const {
+    return pimpl->token_to_piece(token);
+}
+
+int32_t llama_vocab::token_to_piece(llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) const {
+    return pimpl->token_to_piece(token, buf, length, lstrip, special);
+}
+
+int32_t llama_vocab::detokenize(
+               const llama_token * tokens,
+                         int32_t   n_tokens,
+                            char * text,
+                         int32_t   text_len_max,
+                            bool   remove_special,
+                            bool   unparse_special) const {
+    return pimpl->detokenize(tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
+}
+
+std::string llama_vocab::detokenize(const std::vector & tokens, bool special) const {
     std::string text;
     text.resize(std::max(text.capacity(), tokens.size()));
-    int32_t n_chars = llama_detokenize_impl(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+    int32_t n_chars = detokenize(tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
     if (n_chars < 0) {
         text.resize(-n_chars);
-        n_chars = llama_detokenize_impl(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+        n_chars = detokenize(tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
         GGML_ASSERT(n_chars <= (int32_t)text.size());  // whitespace trimming is performed after per-token detokenization
     }
 
@@ -1980,3 +3002,243 @@ std::string llama_detokenize(const struct llama_vocab & vocab, const std::vector
     // NOTE: the original tokenizer decodes bytes after collecting the pieces.
     return text;
 }
+
+void llama_vocab::print_info() const {
+    pimpl->print_info();
+}
+
+//
+// interface implementation
+//
+
+int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab) {
+    return vocab->n_tokens();
+}
+
+// deprecated
+int32_t llama_n_vocab(const struct llama_vocab * vocab) {
+    return llama_vocab_n_tokens(vocab);
+}
+
+enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab) {
+    return vocab->get_type();
+}
+
+const char * llama_vocab_get_text(const struct llama_vocab * vocab, llama_token token) {
+    return vocab->token_get_text(token);
+}
+
+float llama_vocab_get_score(const struct llama_vocab * vocab, llama_token token) {
+    return vocab->token_get_score(token);
+}
+
+enum llama_token_attr llama_vocab_get_attr(const struct llama_vocab * vocab, llama_token token) {
+    return vocab->token_get_attr(token);
+}
+
+bool llama_vocab_is_eog(const struct llama_vocab * vocab, llama_token token) {
+    return vocab->is_eog(token);
+}
+
+bool llama_vocab_is_control(const struct llama_vocab * vocab, llama_token token) {
+    return vocab->is_control(token);
+}
+
+llama_token llama_vocab_bos(const struct llama_vocab * vocab) {
+    return vocab->token_bos();
+}
+
+llama_token llama_vocab_eos(const struct llama_vocab * vocab) {
+    return vocab->token_eos();
+}
+
+llama_token llama_vocab_eot(const struct llama_vocab * vocab) {
+    return vocab->token_eot();
+}
+
+// deprecated
+llama_token llama_vocab_cls(const struct llama_vocab * vocab) {
+    return vocab->token_bos();
+}
+
+llama_token llama_vocab_sep(const struct llama_vocab * vocab) {
+    return vocab->token_sep();
+}
+
+llama_token llama_vocab_nl (const struct llama_vocab * vocab) {
+    return vocab->token_nl();
+}
+
+llama_token llama_vocab_pad(const struct llama_vocab * vocab) {
+    return vocab->token_pad();
+}
+
+bool llama_vocab_get_add_bos(const struct llama_vocab * vocab) {
+    return vocab->get_add_bos();
+}
+
+bool llama_vocab_get_add_eos(const struct llama_vocab * vocab) {
+    return vocab->get_add_eos();
+}
+
+llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab) {
+    return vocab->token_fim_pre();
+}
+
+llama_token llama_vocab_fim_suf(const struct llama_vocab * vocab) {
+    return vocab->token_fim_suf();
+}
+
+llama_token llama_vocab_fim_mid(const struct llama_vocab * vocab) {
+    return vocab->token_fim_mid();
+}
+
+llama_token llama_vocab_fim_pad(const struct llama_vocab * vocab) {
+    return vocab->token_fim_pad();
+}
+
+llama_token llama_vocab_fim_rep(const struct llama_vocab * vocab) {
+    return vocab->token_fim_rep();
+}
+
+llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab) {
+    return vocab->token_fim_sep();
+}
+
+// deprecated
+const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token) {
+    return llama_vocab_get_text(vocab, token);
+}
+
+// deprecated
+float llama_token_get_score(const struct llama_vocab * vocab, llama_token token) {
+    return llama_vocab_get_score(vocab, token);
+}
+
+// deprecated
+enum llama_token_attr llama_token_get_attr(const struct llama_vocab * vocab, llama_token token) {
+    return llama_vocab_get_attr(vocab, token);
+}
+
+// deprecated
+bool llama_token_is_eog(const struct llama_vocab * vocab, llama_token token) {
+    return llama_vocab_is_eog(vocab, token);
+}
+
+// deprecated
+bool llama_token_is_control(const struct llama_vocab * vocab, llama_token token) {
+    return llama_vocab_is_control(vocab, token);
+}
+
+// deprecated
+llama_token llama_token_bos(const struct llama_vocab * vocab) {
+    return llama_vocab_bos(vocab);
+}
+
+// deprecated
+llama_token llama_token_eos(const struct llama_vocab * vocab) {
+    return llama_vocab_eos(vocab);
+}
+
+// deprecated
+llama_token llama_token_eot(const struct llama_vocab * vocab) {
+    return llama_vocab_eot(vocab);
+}
+
+// deprecated
+llama_token llama_token_cls(const struct llama_vocab * vocab) {
+    //return llama_vocab_cls(vocab);
+    return llama_vocab_bos(vocab); // avoid deprecation warning
+}
+
+// deprecated
+llama_token llama_token_sep(const struct llama_vocab * vocab) {
+    return llama_vocab_sep(vocab);
+}
+
+// deprecated
+llama_token llama_token_nl (const struct llama_vocab * vocab) {
+    return llama_vocab_nl(vocab);
+}
+
+// deprecated
+llama_token llama_token_pad(const struct llama_vocab * vocab) {
+    return llama_vocab_pad(vocab);
+}
+
+// deprecated
+bool llama_add_bos_token(const struct llama_vocab * vocab) {
+    return llama_vocab_get_add_bos(vocab);
+}
+
+// deprecated
+bool llama_add_eos_token(const struct llama_vocab * vocab) {
+    return llama_vocab_get_add_eos(vocab);
+}
+
+// deprecated
+llama_token llama_token_fim_pre(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_pre(vocab);
+}
+
+// deprecated
+llama_token llama_token_fim_suf(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_suf(vocab);
+}
+
+// deprecated
+llama_token llama_token_fim_mid(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_mid(vocab);
+}
+
+// deprecated
+llama_token llama_token_fim_pad(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_pad(vocab);
+}
+
+// deprecated
+llama_token llama_token_fim_rep(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_rep(vocab);
+}
+
+// deprecated
+llama_token llama_token_fim_sep(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_sep(vocab);
+}
+
+//
+// tokenization
+//
+
+int32_t llama_tokenize(
+    const struct llama_vocab * vocab,
+                  const char * text,
+                     int32_t   text_len,
+                 llama_token * tokens,
+                     int32_t   n_tokens_max,
+                        bool   add_special,
+                        bool   parse_special) {
+    return vocab->tokenize(text, text_len, tokens, n_tokens_max, add_special, parse_special);
+}
+
+int32_t llama_token_to_piece(
+    const struct llama_vocab * vocab,
+                 llama_token   token,
+                        char * buf,
+                     int32_t   length,
+                     int32_t   lstrip,
+                        bool   special) {
+    return vocab->token_to_piece(token, buf, length, lstrip, special);
+}
+
+int32_t llama_detokenize(
+    const struct llama_vocab * vocab,
+           const llama_token * tokens,
+                     int32_t   n_tokens,
+                        char * text,
+                     int32_t   text_len_max,
+                        bool   remove_special,
+                        bool   unparse_special) {
+    return vocab->detokenize(tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
+}
+
diff --git a/llama/llama.cpp/src/llama-vocab.h b/llama/llama.cpp/src/llama-vocab.h
index 0d00086d..5ce35521 100644
--- a/llama/llama.cpp/src/llama-vocab.h
+++ b/llama/llama.cpp/src/llama-vocab.h
@@ -4,179 +4,122 @@
 
 #include 
 #include 
-#include 
-#include 
-#include 
+#include 
 
-static const char * llama_model_vocab_type_name(enum llama_vocab_type type){
-    switch (type) {
-        case LLAMA_VOCAB_TYPE_NONE: return "no vocab";
-        case LLAMA_VOCAB_TYPE_SPM:  return "SPM";
-        case LLAMA_VOCAB_TYPE_BPE:  return "BPE";
-        case LLAMA_VOCAB_TYPE_WPM:  return "WPM";
-        case LLAMA_VOCAB_TYPE_UGM:  return "UGM";
-        case LLAMA_VOCAB_TYPE_RWKV: return "RWKV";
-        default:                    return "unknown";
-    }
-}
-
-struct llm_tokenizer;
+struct LLM_KV;
+struct llama_model_loader;
 
 struct llama_vocab {
-    using id    = llama_token;
-    using token = std::string;
-    using tattr = llama_token_attr;
-
     struct token_data {
-        token text;
-        float score;
-        tattr attr;
+        std::string      text;
+        float            score;
+        llama_token_attr attr;
     };
 
-    uint32_t n_vocab = 0; // TODO: not great because has to keep in sync with hparams.n_vocab
-
-    enum llama_vocab_type     type     = LLAMA_VOCAB_TYPE_SPM;
-    enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
-
-    int max_token_len = 0; // used for optimizing longest token search
-
-    std::unordered_map token_to_id;
-    std::vector       id_to_token;
-
-    std::vector    cache_special_tokens;
-    std::vector cache_token_to_piece; // llama_token_to_piece(special = true);
-
-    std::map, int> bpe_ranks;
-
-    // default LLaMA special tokens
-    // TODO: should we set all of these to LLAMA_TOKEN_NULL?
-    id special_bos_id  = 1;
-    id special_eos_id  = 2;
-    id special_eot_id  = LLAMA_TOKEN_NULL;
-    id special_eom_id  = LLAMA_TOKEN_NULL;
-    id special_unk_id  = 0;
-    id special_sep_id  = LLAMA_TOKEN_NULL;
-    id special_pad_id  = LLAMA_TOKEN_NULL;
-    id special_cls_id  = LLAMA_TOKEN_NULL; // TODO: revisit if this is really needed https://github.com/ggerganov/llama.cpp/pull/10930
-    id special_mask_id = LLAMA_TOKEN_NULL;
-
-    id linefeed_id = 13;
-
-    // fim tokens
-    id special_fim_pre_id = LLAMA_TOKEN_NULL;
-    id special_fim_suf_id = LLAMA_TOKEN_NULL;
-    id special_fim_mid_id = LLAMA_TOKEN_NULL;
-    id special_fim_pad_id = LLAMA_TOKEN_NULL;
-    id special_fim_rep_id = LLAMA_TOKEN_NULL; // repo
-    id special_fim_sep_id = LLAMA_TOKEN_NULL; // file separator
-
-    // set of all tokens that cause "end of generation"
-    std::set special_eog_ids;
-
-    // tokenizer flags
-    bool tokenizer_add_space_prefix           = false;
-    bool tokenizer_add_bos                    = false;
-    bool tokenizer_add_eos                    = false;
-    bool tokenizer_ignore_merges              = false;
-    bool tokenizer_clean_spaces               = false;  // clean_up_tokenization_spaces
-    bool tokenizer_remove_extra_whitespaces   = false;
-    bool tokenizer_escape_whitespaces         = true;
-    bool tokenizer_treat_whitespace_as_suffix = false;
-
-    std::vector precompiled_charsmap;
-
-    llm_tokenizer * tokenizer = nullptr;
-
-    llama_vocab() = default;
+    llama_vocab();
     ~llama_vocab();
 
+    void load(llama_model_loader & ml, const LLM_KV & kv);
+
+    enum llama_vocab_type     get_type()     const;
+    enum llama_vocab_pre_type get_pre_type() const;
+
+    uint32_t n_tokens() const;
+    uint32_t n_token_types() const;
+
+    std::string type_name() const;
+
+    bool is_normal      (llama_token id) const;
+    bool is_unknown     (llama_token id) const;
+    bool is_control     (llama_token id) const;
+    bool is_byte        (llama_token id) const;
+    bool is_user_defined(llama_token id) const;
+    bool is_unused      (llama_token id) const;
+    bool is_eog         (llama_token id) const;
+
+    uint8_t     token_to_byte(llama_token id) const;
+    llama_token byte_to_token(uint8_t ch)     const;
+
+    llama_token text_to_token(const std::string & text) const;
+
+    const token_data & get_token_data(llama_token id) const;
+
+    const char *     token_get_text (llama_token id) const;
+    float            token_get_score(llama_token id) const;
+    llama_token_attr token_get_attr (llama_token id) const;
+
+    llama_token token_bos() const;
+    llama_token token_eos() const;
+    llama_token token_eot() const;
+    llama_token token_eom() const;
+    llama_token token_unk() const;
+    llama_token token_sep() const;
+    llama_token token_nl () const;
+    llama_token token_pad() const;
+
+    llama_token token_prefix() const;
+    llama_token token_middle() const;
+    llama_token token_suffix() const;
+
+    llama_token token_fim_pre() const;
+    llama_token token_fim_suf() const;
+    llama_token token_fim_mid() const;
+    llama_token token_fim_pad() const;
+    llama_token token_fim_rep() const;
+    llama_token token_fim_sep() const;
+
+    bool get_add_space_prefix          () const;
+    bool get_add_bos                   () const;
+    bool get_add_eos                   () const;
+    bool get_ignore_merges             () const;
+    bool get_clean_spaces              () const;
+    bool get_remove_extra_whitespaces  () const;
+    bool get_escape_whitespaces        () const;
+    bool get_treat_whitespace_as_suffix() const;
+
+    int max_token_len() const;
+
     int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
 
-    void init_tokenizer();
+    int32_t tokenize(
+                   const char * text,
+                      int32_t   text_len,
+                  llama_token * tokens,
+                      int32_t   n_tokens_max,
+                         bool   add_special,
+                         bool   parse_special) const;
+
+    std::vector tokenize(
+            const std::string & raw_text,
+                         bool   add_special,
+                         bool   parse_special = false) const;
+
+    // does not write null-terminator to buf
+    int32_t token_to_piece(
+                  llama_token   token,
+                         char * buf,
+                      int32_t   length,
+                      int32_t   lstrip,
+                         bool   special) const;
+
+    // use cached data
+    const std::string & token_to_piece(llama_token token) const;
+
+    int32_t detokenize(
+            const llama_token * tokens,
+                      int32_t   n_tokens,
+                         char * text,
+                      int32_t   text_len_max,
+                         bool   remove_special,
+                         bool   unparse_special) const;
+
+    std::string detokenize(
+            const std::vector & tokens,
+                                      bool   special) const;
+
+    void print_info() const;
+
+private:
+    struct impl;
+    std::unique_ptr pimpl;
 };
-
-//
-// internal API
-//
-
-// TODO: rename to llama_tokenize_impl
-// TODO: This should probably be in llama.h
-std::vector llama_tokenize_internal(
-        const llama_vocab & vocab,
-        std::string raw_text,
-        bool add_special,
-        bool parse_special = false);
-
-// TODO: move the API below as member functions of llama_vocab
-llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch);
-
-const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token);
-
-float llama_token_get_score_impl(const struct llama_vocab & vocab, llama_token token);
-
-llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, llama_token token);
-
-bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token);
-
-bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token);
-
-llama_token llama_token_bos_impl(const struct llama_vocab & vocab);
-llama_token llama_token_eos_impl(const struct llama_vocab & vocab);
-llama_token llama_token_eot_impl(const struct llama_vocab & vocab);
-llama_token llama_token_eom_impl(const struct llama_vocab & vocab);
-llama_token llama_token_cls_impl(const struct llama_vocab & vocab);
-llama_token llama_token_sep_impl(const struct llama_vocab & vocab);
-llama_token llama_token_nl_impl (const struct llama_vocab & vocab);
-llama_token llama_token_pad_impl(const struct llama_vocab & vocab);
-
-llama_token llama_token_prefix_impl(const struct llama_vocab & vocab);
-llama_token llama_token_middle_impl(const struct llama_vocab & vocab);
-llama_token llama_token_suffix_impl(const struct llama_vocab & vocab);
-
-llama_token llama_token_fim_pre_impl(const struct llama_vocab & vocab);
-llama_token llama_token_fim_suf_impl(const struct llama_vocab & vocab);
-llama_token llama_token_fim_mid_impl(const struct llama_vocab & vocab);
-llama_token llama_token_fim_pad_impl(const struct llama_vocab & vocab);
-llama_token llama_token_fim_rep_impl(const struct llama_vocab & vocab);
-llama_token llama_token_fim_sep_impl(const struct llama_vocab & vocab);
-
-bool llama_add_bos_token_impl(const struct llama_vocab & vocab);
-bool llama_add_eos_token_impl(const struct llama_vocab & vocab);
-
-int32_t llama_tokenize_impl(
-        const struct llama_vocab & vocab,
-                      const char * text,
-                         int32_t   text_len,
-                     llama_token * tokens,
-                         int32_t   n_tokens_max,
-                            bool   add_special,
-                            bool   parse_special);
-
-// does not write null-terminator to buf
-int32_t llama_token_to_piece_impl(
-        const struct llama_vocab & vocab,
-                     llama_token   token,
-                            char * buf,
-                         int32_t   length,
-                         int32_t   lstrip,
-                            bool   special);
-
-// check if token0 is contained as a prefix in token1
-bool llama_token_is_prefix_impl(
-        const struct llama_vocab & vocab,
-                     llama_token   token0,
-                     llama_token   token1);
-
-int32_t llama_detokenize_impl(
-        const struct llama_vocab & vocab,
-               const llama_token * tokens,
-                         int32_t   n_tokens,
-                            char * text,
-                         int32_t   text_len_max,
-                            bool   remove_special,
-                            bool   unparse_special);
-
-std::string llama_detokenize(
-        const struct llama_vocab & vocab,
-  const std::vector & tokens,
-                            bool   special);
diff --git a/llama/llama.cpp/src/llama.cpp b/llama/llama.cpp/src/llama.cpp
index c95da45d..01854fce 100644
--- a/llama/llama.cpp/src/llama.cpp
+++ b/llama/llama.cpp/src/llama.cpp
@@ -8,7 +8,6 @@
 #include "llama-kv-cache.h"
 #include "llama-model-loader.h"
 #include "llama-model.h"
-#include "llama-quant.h"
 
 #include "ggml.h"
 #include "ggml-alloc.h"
@@ -18,2560 +17,60 @@
 #include 
 #include 
 #include 
-#include 
 #include 
-#include 
-#include 
 #include 
-#include 
 #include 
 #include 
 #include 
 #include 
 #include 
 #include 
-#include 
-#include 
-#include 
-#include 
-#include 
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
-//
-// tensor loading (TODO: add llama_tesor_loader?)
-//
-
-static int llama_get_device_count(const llama_model & model) {
-    return (int) model.devices.size();
-}
-
-// checks if the weight tensor can be used with the specified buffer type and device
-static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) {
-    GGML_ASSERT(w != nullptr);
-
-    if (op == GGML_OP_NONE) {
-        return true;
-    }
-
-    ggml_init_params params = {
-        /*.mem_size   =*/ ggml_tensor_overhead()*8,
-        /*.mem_buffer =*/ NULL,
-        /*.no_alloc   =*/ true,
-    };
-    ggml_context_ptr ctx_ptr { ggml_init(params) };
-    if (!ctx_ptr) {
-        throw std::runtime_error(format("failed to create ggml context"));
-    }
-    ggml_context * ctx = ctx_ptr.get();
-
-    ggml_tensor * op_tensor = nullptr;
-
-    switch (op) {
-        case GGML_OP_GET_ROWS:
-            {
-                ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512);
-                op_tensor = ggml_get_rows(ctx, w, b);
-            } break;
-        case GGML_OP_MUL_MAT:
-            {
-                ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], 512, w->ne[2], w->ne[3]);
-                op_tensor = ggml_mul_mat(ctx, w, b);
-            } break;
-        case GGML_OP_MUL_MAT_ID:
-            {
-                int n_expert_used = hparams.n_expert_used;
-                ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0], n_expert_used, 512);
-                ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_expert_used, 512);
-                op_tensor = ggml_mul_mat_id(ctx, w, b, ids);
-            } break;
-        case GGML_OP_ADD:
-            {
-                ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]);
-                op_tensor = ggml_add(ctx, a, w);
-            } break;
-        case GGML_OP_MUL:
-            {
-                ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]);
-                op_tensor = ggml_mul(ctx, a, w);
-            } break;
-        case GGML_OP_DIV:
-            {
-                ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, w->ne[0]);
-                op_tensor = ggml_div(ctx, a, w);
-            } break;
-        case GGML_OP_ROPE:
-            {
-                int n_embd_head = hparams.n_embd_head_v;
-                int n_head = hparams.n_head();
-                ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, 512);
-                ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512);
-                op_tensor = ggml_rope_ext(
-                    ctx, a, b, w,
-                    0, 0, 0, 0, 0,
-                    0, 0, 0, 0
-                );
-
-            } break;
-        case GGML_OP_SSM_CONV:
-            {
-                // FIXME
-                ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 12345, w->ne[1], 6789);
-                op_tensor = ggml_ssm_conv(ctx, conv_x, w);
-            } break;
-        case GGML_OP_SSM_SCAN:
-            {
-                // FIXME
-                const int64_t d_state      = w->ne[0];
-                const int64_t d_inner      = w->ne[1];
-                const int64_t n_seq_tokens = 512;
-                const int64_t n_seqs       = 1;
-                ggml_tensor * s  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, d_inner, n_seqs);
-                ggml_tensor * x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs);
-                ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs);
-                ggml_tensor * B = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs);
-                ggml_tensor * C = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs);
-                op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C);
-            } break;
-        case GGML_OP_RWKV_WKV6:
-            {
-                // FIXME
-                const int64_t S = 123;
-                const int64_t H = 123;
-                const int64_t n_tokens = 123;
-                const int64_t n_seqs = 123;
-                ggml_tensor  * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, 1, H, n_tokens);
-                ggml_tensor  * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
-                ggml_tensor  * r = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
-                ggml_tensor  * tf = w;
-                ggml_tensor  * td = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
-                ggml_tensor  * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_seqs, S, H);
-                op_tensor = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, state);
-            } break;
-        case GGML_OP_IM2COL:
-            {
-                const int n_embd = hparams.n_embd;
-                ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd, w->ne[1], 1, 1);
-                op_tensor = ggml_im2col(ctx, w, b, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F16);
-            } break;
-        default:
-            GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name);
-    }
-
-    // create a temporary dummy buffer for the weight so that supports_op can check the buffer type
-    GGML_ASSERT(w->buffer == nullptr);
-    w->buffer = ggml_backend_buft_alloc_buffer(buft, 0);
-    bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
-    ggml_backend_buffer_free(w->buffer);
-    w->buffer = nullptr;
-
-    return op_supported;
-}
-
-// find the first buffer type in the list that can use the tensor
-static ggml_backend_buffer_type_t select_weight_buft(const llama_model & model, ggml_tensor * tensor, ggml_op op, const llama_model::buft_list_t & buft_list) {
-    GGML_ASSERT(!buft_list.empty());
-    for (const auto & cur : buft_list) {
-        ggml_backend_dev_t cur_dev = cur.first;
-        ggml_backend_buffer_type_t cur_buft = cur.second;
-        if (weight_buft_supported(model.hparams, tensor, op, cur_buft, cur_dev)) {
-            return cur_buft;
-        }
-    }
-    return nullptr;
-}
-
-// CPU: ACCEL -> CPU extra -> GPU host -> CPU
-static llama_model::buft_list_t make_cpu_buft_list(llama_model & model) {
-    llama_model::buft_list_t buft_list;
-
-    // add ACCEL buffer types
-    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
-        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
-        if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
-            auto * buft = ggml_backend_dev_buffer_type(dev);
-            // skip
-            if (buft != ggml_backend_cpu_buffer_type()) {
-                buft_list.emplace_back(dev, buft);
-            }
-        }
-    }
-
-    // add extra buffer types
-    auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
-    auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
-    auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
-        ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
-    if (ggml_backend_dev_get_extra_bufts_fn) {
-        ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev);
-        while (extra_bufts && *extra_bufts) {
-            buft_list.emplace_back(cpu_dev, *extra_bufts);
-            ++extra_bufts;
-        }
-    }
-
-    // add a host buffer type
-    // storing the tensors in a host buffer is useful when the processing of large batches
-    // is offloaded to a GPU device, since it reduces the time spent on data transfers
-    // generally, this will be done using the first device in the list
-    // a better approach would be to handle this on a weight-by-weight basis using the offload_op
-    // function of the device to determine if it would benefit from being stored in a host buffer
-    for (auto * dev : model.devices) {
-        ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev);
-        if (buft) {
-            buft_list.emplace_back(dev, buft);
-            break;
-        }
-    }
-
-    // add the CPU buffer type
-    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
-        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
-        if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU) {
-            buft_list.emplace_back(dev, ggml_backend_dev_buffer_type(dev));
-        }
-    }
-
-    return buft_list;
-}
-
-// GPU: split if LLAMA_SPLIT_MODE_ROW -> GPU
-static llama_model::buft_list_t make_gpu_buft_list(ggml_backend_dev_t dev, enum llama_split_mode split_mode, const float * tensor_split) {
-    llama_model::buft_list_t buft_list;
-
-    // add the device split buffer type if requested and available
-    if (split_mode == LLAMA_SPLIT_MODE_ROW) {
-        ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev);
-        auto ggml_backend_split_buffer_type_fn = (ggml_backend_split_buffer_type_t)
-            ggml_backend_reg_get_proc_address(reg, "ggml_backend_split_buffer_type");
-        if (ggml_backend_split_buffer_type_fn) {
-            size_t dev_index = [&]() {
-                auto * reg = ggml_backend_dev_backend_reg(dev);
-                for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); ++i) {
-                    if (ggml_backend_reg_dev_get(reg, i) == dev) {
-                        return i;
-                    }
-                }
-                throw std::runtime_error(format("device %s not found in its backend reg", ggml_backend_dev_name(dev)));
-            }();
-            auto * buft = ggml_backend_split_buffer_type_fn(dev_index, tensor_split);
-            if (buft != nullptr) {
-                buft_list.emplace_back(dev, buft);
-            }
-        }
-    }
-
-    // add the device default buffer type
-    buft_list.emplace_back(dev, ggml_backend_dev_buffer_type(dev));
-
-    return buft_list;
-}
-
-// Returns false if cancelled by progress_callback
-static bool llm_load_tensors(
-        llama_model_loader & ml,
-        llama_model & model,
-        int n_gpu_layers,
-        enum llama_split_mode split_mode,
-        int main_gpu,
-        const float * tensor_split,
-        bool use_mlock,
-        llama_progress_callback progress_callback,
-        void * progress_callback_user_data) {
-    auto & hparams = model.hparams;
-
-    model.split_mode   = split_mode;
-    model.main_gpu     = main_gpu;
-    model.n_gpu_layers = n_gpu_layers;
-
-    const int n_layer = hparams.n_layer;
-
-    bool use_mmap_buffer = true;
-
-    // build a list of buffer types for the CPU and GPU devices
-    model.cpu_buft_list = make_cpu_buft_list(model);
-    for (auto * dev : model.devices) {
-        llama_model::buft_list_t buft_list = make_gpu_buft_list(dev, split_mode, tensor_split);
-        // add CPU buffer types as a fallback
-        buft_list.insert(buft_list.end(), model.cpu_buft_list.begin(), model.cpu_buft_list.end());
-        model.gpu_buft_list.emplace(dev, std::move(buft_list));
-    }
-
-    // calculate the split points
-    int device_count = llama_get_device_count(model);
-    bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + device_count, [](float x) { return x == 0.0f; });
-    std::vector splits(device_count);
-    if (all_zero) {
-        // default split, by free memory
-        for (int i = 0; i < device_count; ++i) {
-            ggml_backend_dev_t dev = model.devices[i];
-            size_t total;
-            size_t free;
-            ggml_backend_dev_memory(dev, &free, &total);
-            splits[i] = free;
-        }
-    } else {
-        std::copy(tensor_split, tensor_split + device_count, splits.begin());
-    }
-
-    // sum and normalize the splits to get the split points
-    float split_sum = 0.0f;
-    for (int i = 0; i < device_count; ++i) {
-        split_sum += splits[i];
-        splits[i] = split_sum;
-    }
-    for (int i = 0; i < device_count; ++i) {
-        splits[i] /= split_sum;
-    }
-
-    ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
-    const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0);
-    const int act_gpu_layers = model.devices.empty() ? 0 : std::min(n_gpu_layers, (int)n_layer + 1);
-    auto get_layer_buft_list = [&](int il) -> llama_model::layer_dev {
-        if (il < i_gpu_start || (il - i_gpu_start) >= act_gpu_layers) {
-            return {cpu_dev, &model.cpu_buft_list};
-        }
-        int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + device_count, float(il - i_gpu_start)/act_gpu_layers) - splits.begin();
-        auto * dev = model.devices.at(layer_gpu);
-        return {dev, &model.gpu_buft_list.at(dev)};
-    };
-
-    // assign the input layer
-    // there is very little benefit to offloading the input layer, so always keep it on the CPU
-    model.dev_input = { cpu_dev, &model.cpu_buft_list };
-
-    // assign the repeating layers to the devices according to the splits
-    model.dev_layer.resize(n_layer);
-    for (int il = 0; il < n_layer; ++il) {
-        model.dev_layer[il] = get_layer_buft_list(il);
-    }
-    // assign the output layer
-    model.dev_output = get_layer_buft_list(n_layer);
-
-    // one ggml context per buffer type
-    int max_n_tensors = ml.n_tensors;
-    max_n_tensors += 1;         // duplicated output tensor
-    max_n_tensors += n_layer*2; // duplicated rope freq tensors
-    const size_t ctx_size = ggml_tensor_overhead()*max_n_tensors;
-
-    std::map ctx_map;
-    auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
-        auto it = ctx_map.find(buft);
-        if (it == ctx_map.end()) {
-            ggml_init_params params = {
-                /*.mem_size   =*/ ctx_size,
-                /*.mem_buffer =*/ NULL,
-                /*.no_alloc   =*/ true,
-            };
-            ggml_context * ctx = ggml_init(params);
-            if (!ctx) {
-                throw std::runtime_error(format("failed to create ggml context"));
-            }
-            ctx_map[buft] = ctx;
-            model.ctxs.emplace_back(ctx);
-            return ctx;
-        }
-        return it->second;
-    };
-
-    // create tensors for the weights
-    {
-        // note: cast to int64_t since we will use these for the tensor dimensions
-        const int64_t n_head        = hparams.n_head();
-        const int64_t n_head_kv     = hparams.n_head_kv();
-        const int64_t n_embd        = hparams.n_embd;
-        const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa();
-        const int64_t n_embd_v_gqa  = hparams.n_embd_v_gqa();
-        const int64_t n_embd_head_k = hparams.n_embd_head_k;
-        const int64_t n_embd_head_v = hparams.n_embd_head_v;
-        const int64_t n_ff          = hparams.n_ff();
-        const int64_t n_embd_gqa    = n_embd_v_gqa;
-        const int64_t n_vocab       = hparams.n_vocab;
-        const int64_t n_vocab_type  = hparams.n_vocab_type;
-        const int64_t n_rot         = hparams.n_rot;
-        const int64_t n_expert      = hparams.n_expert;
-        const int64_t n_expert_used = hparams.n_expert_used;
-        const int64_t n_ctx_train   = hparams.n_ctx_train;
-
-        if (n_expert > 0 && hparams.n_expert_used == 0) {
-            throw std::runtime_error("model has expert layers but no expert layers are used");
-        }
-
-        int n_moved_tensors = 0;
-        ggml_tensor * first_moved_tensor = nullptr;
-        ggml_backend_buffer_type_t first_moved_from_buft = nullptr;
-        ggml_backend_buffer_type_t first_moved_to_buft = nullptr;
-
-        auto create_tensor = [&](const LLM_TN_IMPL & tn, const std::initializer_list & ne, int flags) -> ggml_tensor * {
-            ggml_tensor * t_meta = ml.get_tensor_meta(tn.str().c_str());
-
-            if (!t_meta) {
-                if (flags & llama_model_loader::TENSOR_NOT_REQUIRED) {
-                    return nullptr;
-                }
-                throw std::runtime_error(format("missing tensor '%s'", tn.str().c_str()));
-            }
-
-            // some models use the token embedding tensor as the output, but since these are used in different layers and with different ops
-            // the tensor is duplicated
-            // to handle this, we check if the tensor is duplicated, and if so, we assume that it is being loaded as the output tensor
-            llm_tensor tn_tensor = tn.tensor;
-            if (tn.tensor == LLM_TENSOR_TOKEN_EMBD && flags & llama_model_loader::TENSOR_DUPLICATED) {
-                tn_tensor = LLM_TENSOR_OUTPUT;
-            }
-
-            llm_tensor_info info;
-            try {
-                info = llm_tensor_info_for(tn_tensor);
-            } catch (const std::out_of_range & e) {
-                throw std::runtime_error(format("missing tensor info mapping for %s", tn.str().c_str()));
-            }
-
-            // tensors with "bias" suffix are always used with GGML_OP_ADD
-            ggml_op op;
-            bool bias = tn.suffix != nullptr && strcmp(tn.suffix, "bias") == 0;
-            if (bias) {
-                op = GGML_OP_ADD;
-            } else {
-                op = info.op;
-            }
-
-            // sanity checks
-            if (info.layer == LLM_TENSOR_LAYER_INPUT || info.layer == LLM_TENSOR_LAYER_OUTPUT) {
-                if (tn.bid != -1) {
-                    GGML_ABORT("input/output layer tensor %s used with a layer number", tn.str().c_str());
-                }
-            } else {
-                if (tn.bid == -1) {
-                    GGML_ABORT("repeating layer tensor %s used without a layer number", tn.str().c_str());
-                }
-            }
-
-            // select the buffer type for this tensor
-            llama_model::buft_list_t * buft_list;
-            switch (info.layer) {
-                case LLM_TENSOR_LAYER_INPUT:
-                    buft_list = model.dev_input.buft_list;
-                    break;
-                case LLM_TENSOR_LAYER_OUTPUT:
-                    buft_list = model.dev_output.buft_list;
-                    break;
-                case LLM_TENSOR_LAYER_REPEATING:
-                    buft_list = model.dev_layer.at(tn.bid).buft_list;
-                    break;
-                default:
-                    GGML_ABORT("invalid layer %d for tensor %s", info.layer, tn.str().c_str());
-            }
-
-            ggml_backend_buffer_type_t buft = select_weight_buft(model, t_meta, op, *buft_list);
-            if (!buft) {
-                throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str()));
-            }
-
-            // avoid using a host buffer when using mmap
-            auto * buft_dev = ggml_backend_buft_get_device(buft);
-            if (ml.use_mmap && buft_dev && buft == ggml_backend_dev_host_buffer_type(buft_dev)) {
-                auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
-                buft = ggml_backend_dev_buffer_type(cpu_dev);
-            }
-
-            if (buft != buft_list->front().second) {
-                n_moved_tensors++;
-                if (!first_moved_tensor) {
-                    first_moved_tensor = t_meta;
-                    first_moved_from_buft = buft_list->front().second;
-                    first_moved_to_buft   = buft;
-                }
-            }
-
-            ggml_context * ctx = ctx_for_buft(buft);
-
-            // if duplicated, check if the original tensor was allocated in the same buffer type context and avoid creating a new one
-            if (flags & llama_model_loader::TENSOR_DUPLICATED) {
-                ggml_tensor * t = ggml_get_tensor(ctx, tn.str().c_str());
-                if (t) {
-                    return t;
-                }
-            }
-            return ml.create_tensor(ctx, tn, ne, flags);
-        };
-
-        model.layers.resize(n_layer);
-
-        // TODO: move to a separate function
-        const auto tn = LLM_TN(model.arch);
-        switch (model.arch) {
-            case LLM_ARCH_LLAMA:
-            case LLM_ARCH_REFACT:
-            case LLM_ARCH_MINICPM:
-            case LLM_ARCH_GRANITE:
-            case LLM_ARCH_GRANITE_MOE:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                    // if output is NULL, init from the input tok embed
-                    if (model.output == NULL) {
-                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
-
-                        // optional bias tensors
-                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) {
-                            layer.rope_long  = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                            layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                        }
-                        else {
-                            layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                        }
-
-                        if (n_expert == 0) {
-                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-
-                            // optional MLP bias
-                            layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                            layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                            layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        } else {
-                            layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert}, 0);
-                            layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd,   n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                            layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff, n_embd, n_expert}, 0);
-                            layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff, n_expert}, 0);
-                        }
-                    }
-                } break;
-            case LLM_ARCH_MLLAMA:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab+8}, 0);
-
-                    // output
-                    {
-                        model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                        model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        // if output is NULL, init from the input tok embed
-                        if (model.output == NULL) {
-                            model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                        }
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        if (hparams.cross_attention_layers(i)) {
-                            layer.cross_attn_k_norm = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_K_NORM,   "weight", i), {128}, 0);
-                            layer.cross_attn_k_proj = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_K_PROJ,   "weight", i), {n_embd, 1024}, 0);
-                            layer.cross_attn_o_proj = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_O_PROJ,   "weight", i), {n_embd, n_embd}, 0);
-                            layer.cross_attn_q_norm = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_Q_NORM, "weight", i), {128}, 0);
-                            layer.cross_attn_q_proj = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_Q_PROJ, "weight", i), {n_embd, n_embd}, 0);
-                            layer.cross_attn_v_proj = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_V_PROJ, "weight", i), {n_embd, 1024}, 0);
-                            layer.cross_attn_attn_gate = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_ATTN_GATE, i), {1}, 0);
-                            layer.cross_attn_mlp_gate = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_MLP_GATE, i), {1}, 0);
-                            layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                            layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        } else {
-                            layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
-                            layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                            layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
-                            layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
-                            layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                            layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                        }
-                    }
-                } break;
-            case LLM_ARCH_DECI:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                    // if output is NULL, init from the input tok embed
-                    if (model.output == NULL) {
-                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-                        const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa(i);
-                        const int64_t n_embd_v_gqa  = hparams.n_embd_v_gqa(i);
-                        const int64_t n_embd_gqa    = hparams.n_embd_v_gqa(i);
-                        const int64_t n_ff          = hparams.n_ff(i);
-                        const int64_t n_head        = hparams.n_head(i);
-                        const int64_t n_head_kv     = hparams.n_head_kv(i);
-
-                        if (n_head_kv == 0 && n_head > 0) {
-                            // linear attention for DeciLMCausalModel
-                            layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                            layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        }
-                        else if (n_head_kv > 0) {
-                            layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
-                            layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                            layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
-                            layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
-                        }
-
-                        // optional bias tensors
-                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) {
-                            layer.rope_long  = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                            layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                        }
-                        else {
-                            layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                        }
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-
-                        // optional MLP bias
-                        layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    }
-                } break;
-            case LLM_ARCH_MINICPM3:
-                {
-                    const int64_t n_embd_head_qk_rope = hparams.n_rot;
-                    const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
-
-                    const int64_t q_lora_rank  = hparams.n_lora_q;
-                    const int64_t kv_lora_rank = hparams.n_lora_kv;
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                    // if output is NULL, init from the input tok embed
-                    if (model.output == NULL) {
-                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0);
-
-                        layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0);
-
-                        layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0);
-                        layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}, 0);
-
-                        layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0);
-                        layer.wkv_b     = create_tensor(tn(LLM_TENSOR_ATTN_KV_B,     "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0);
-                        layer.wo        = create_tensor(tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {              n_head * (                      n_embd_head_v), n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-
-                        layer.rope_long  = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight", i), { n_embd_head_qk_rope/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                        layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head_qk_rope/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                    }
-                } break;
-            case LLM_ARCH_GROK:
-                {
-                    if (n_expert == 0) {
-                        throw std::runtime_error("Grok model cannot have zero experts");
-                    }
-
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                    // if output is NULL, init from the input tok embed
-                    if (model.output == NULL) {
-                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.attn_out_norm   = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert}, 0);
-                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff, n_embd, n_expert}, 0);
-                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff, n_expert}, 0);
-
-                        layer.layer_out_norm   = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
-                    }
-                } break;
-            case LLM_ARCH_DBRX:
-                {
-                    if (n_expert == 0) {
-                        throw std::runtime_error("DBRX model cannot have zero experts");
-                    }
-
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert}, 0);
-                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff,   n_expert}, 0);
-                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff,   n_embd, n_expert}, 0);
-                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff,   n_expert}, 0);
-                    }
-                } break;
-            case LLM_ARCH_BAICHUAN:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-                    {
-                        model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                        model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_FALCON:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    {
-                        model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                        model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-
-                        model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        if (!model.output) {
-                            model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // needs to be on GPU
-                        }
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.attn_norm_2   = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_STARCODER:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-                    model.pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD,   "weight"), {n_embd, n_ctx_train}, 0);
-
-                    // output
-                    {
-                        model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                        model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-                        model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        if (!model.output) {
-                            // needs to be on GPU
-                            model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                        }
-
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, 0);
-
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i),   {n_embd, n_ff}, 0);
-                        layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i),     {n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_BERT:
-            case LLM_ARCH_NOMIC_BERT:
-                {
-                    model.tok_embd     = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, 0);
-                    model.type_embd    = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}, 0);
-
-                    if (model.arch == LLM_ARCH_BERT) {
-                        model.pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD,    "weight"), {n_embd, n_ctx_train}, 0);
-
-                        model.cls   = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        model.cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"),   {n_embd},         llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        model.cls_out   = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        model.cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"),   {1},         llama_model_loader::TENSOR_NOT_REQUIRED);
-                    }
-
-                    model.tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
-                    model.tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        if (model.arch == LLM_ARCH_BERT) {
-                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                            layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i),   {n_embd}, 0);
-
-                            layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                            layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i),   {n_embd_gqa}, 0);
-
-                            layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                            layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i),   {n_embd_gqa}, 0);
-                        } else {
-                            layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                        }
-
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.attn_out_norm   = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,        "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN,      "weight", i), {n_ff, n_embd}, 0);
-
-                        if (model.arch == LLM_ARCH_BERT) {
-                            layer.bo         = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
-                            layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, 0);
-                            layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
-                        } else {
-                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
-                        }
-
-                        layer.layer_out_norm   = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
-                        layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i),   {n_embd}, 0);
-                    }
-                } break;
-            case LLM_ARCH_JINA_BERT_V2:
-                {
-                    model.tok_embd  = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, 0); // word_embeddings
-                    model.type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}, 0); // token_type_embeddings
-
-                    model.tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); // LayerNorm
-                    model.tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd}, 0); //LayerNorm bias
-
-                    model.cls   = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    model.cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"),   {1},         llama_model_loader::TENSOR_NOT_REQUIRED);
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i]; // JinaBertLayer
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
-                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i),   {n_embd}, 0);
-
-                        layer.attn_q_norm   = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias",   i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias",   i), {n_embd_gqa}, 0);
-
-                        layer.attn_k_norm   = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias",   i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias",   i), {n_embd_gqa}, 0);
-
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); //output_dens
-                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias",   i), {n_embd}, 0); //output_dens
-
-                        layer.attn_out_norm   = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); //output_norm
-                        layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias",   i), {n_embd}, 0);
-
-                        layer.attn_norm_2   = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias",   i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
-
-                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias",   i), {n_embd}, 0);
-
-                        layer.layer_out_norm   = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
-                        layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias",   i), {n_embd}, 0);
-                    }
-                } break;
-            case LLM_ARCH_BLOOM:
-                {
-                    model.tok_embd   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,      "weight"), {n_embd, n_vocab}, 0);
-                    model.tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
-                    model.tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd}, 0);
-
-                    // output
-                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias",   i), {n_embd}, 0);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias",   i), {n_embd + 2*n_embd_gqa}, 0);
-
-                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.bo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias",   i), {n_embd}, 0);
-
-                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias",   i), {n_embd}, 0);
-
-                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias",   i), {n_embd}, 0);
-
-                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias",   i), {n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_MPT:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-                    model.pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD,   "weight"), {n_embd, n_ctx_train}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                    // output
-                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    if (!model.output) {
-                        model.output    = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // needs to be on GPU
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.bo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.attn_q_norm   = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias",   i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.attn_k_norm   = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias",   i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        // AWQ ScaleActivation layer
-                        layer.ffn_act = create_tensor(tn(LLM_TENSOR_FFN_ACT, "scales", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    }
-                } break;
-            case LLM_ARCH_STABLELM:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm =   create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        // optional bias tensors, present in Stable LM 2 1.6B
-                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        // optional q and k layernorms, present in StableLM 2 12B
-                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head},    llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        // optional FFN norm, not present in StableLM 2 12B which uses parallel residual
-                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_QWEN:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd*3}, 0);
-                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd*3}, 0);
-                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff/2}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff/2, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff/2}, 0);
-                    }
-                } break;
-            case LLM_ARCH_QWEN2:
-            case LLM_ARCH_QWEN2VL:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    // if output is NULL, init from the input tok embed
-                    if (model.output == NULL) {
-                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        // optional bias tensors
-                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd}, 0);
-                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, 0);
-                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_QWEN2MOE:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        // optional bias tensors
-                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd}, 0);
-                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, 0);
-                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
-
-                        if (n_expert == 0) {
-                            throw std::runtime_error("n_expert must be > 0 for QWEN2MOE");
-                        }
-                        if (n_expert_used == 0) {
-                            throw std::runtime_error("n_expert_used must be > 0 for QWEN2MOE");
-                        }
-
-                        // MoE branch
-                        const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
-
-                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
-                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, 0);
-                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
-
-                        // Shared expert branch
-                        const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff;
-
-                        layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd}, 0);
-                        layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {    n_embd, n_ff_shexp}, 0);
-                        layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp,     n_embd}, 0);
-                        layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {    n_embd, n_ff_shexp}, 0);
-                    }
-                } break;
-            case LLM_ARCH_PHI2:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-                    model.output_b      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "bias"),   {n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        if (layer.wqkv == nullptr) {
-                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
-                            layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i),   {n_embd}, 0);
-
-                            layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
-                            layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i),   {n_embd_gqa}, 0);
-
-                            layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
-                            layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i),   {n_embd_gqa}, 0);
-                        }
-
-                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.bo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_PHI3:
-                {
-                    const int64_t n_embd_head = n_embd / n_head;
-
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
-                    model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0);
-
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
-                        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }, 0);
-
-                        layer.rope_long  = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight", i), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                        layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                    }
-                } break;
-            case LLM_ARCH_PLAMO:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_GPT2:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-                    model.pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD,   "weight"), {n_embd, n_ctx_train}, 0);
-
-                    // output
-                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM,   "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM,   "bias", i),   {n_embd}, 0);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, 0);
-
-                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.bo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_CODESHELL:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, 0);
-
-                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.bo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i),   {n_embd, n_ff}, 0);
-                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i),     {n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_ORION:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_INTERNLM2:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        // layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_GEMMA:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                    }
-                } break;
-            case LLM_ARCH_GEMMA2:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
-                        layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
-                    }
-                } break;
-            case LLM_ARCH_STARCODER2:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-
-                    model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    // if output is NULL, init from the input tok embed
-                    if (model.output == NULL) {
-                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        // optional bias tensors
-                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd}, 0);
-                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, 0);
-                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, 0);
-                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
-
-                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-
-                        // optional bias tensors
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
-                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP ,  "bias", i), {  n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_MAMBA:
-                {
-                    const int64_t d_conv  = hparams.ssm_d_conv;
-                    const int64_t d_inner = hparams.ssm_d_inner;
-                    const int64_t d_state = hparams.ssm_d_state;
-                    const int64_t dt_rank = hparams.ssm_dt_rank;
-
-                    // only an expansion factor of 2 is supported for now
-                    if (2 * n_embd != d_inner) {
-                        throw std::runtime_error("only an expansion factor of 2 is supported for now");
-                    }
-
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-
-                    model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    // if output is NULL, init from the input tok embed, duplicated to allow offloading
-                    if (model.output == NULL) {
-                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        // norm
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}, 0);
-
-                        layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}, 0);
-                        layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}, 0);
-
-                        layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}, 0);
-
-                        layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner}, 0);
-                        layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}, 0);
-
-                        // no "weight" suffix for these
-                        layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}, 0);
-                        layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {d_inner}, 0);
-
-                        // out_proj
-                        layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
-                    }
-                } break;
-            case LLM_ARCH_XVERSE:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_COMMAND_R:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    // init output from the input tok embed
-                    model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        if (n_layer >= 64){
-                            layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, 0);
-                            layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0);
-                        }
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_COHERE2:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
-                    // init output from the input tok embed
-                    model.output      = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab },
-                                                      llama_model_loader::TENSOR_DUPLICATED);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd }, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0);
-                    }
-                }
-                break;
-            case LLM_ARCH_OLMO:  // adapted from LLM_ARCH_LLAMA with norm params removed
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    // if output is NULL, init from the input tok embed
-                    if (model.output == NULL) {
-                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_OLMO2:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
-                    }
-                } break;
-            case LLM_ARCH_OLMOE:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
-
-                        if (n_expert == 0) {
-                            throw std::runtime_error("n_expert must be > 0");
-                        }
-                        if (n_expert_used == 0) {
-                            throw std::runtime_error("n_expert_used must be > 0");
-                        }
-
-                        // MoE branch
-                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff,   n_expert}, 0);
-                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff,   n_embd, n_expert}, 0);
-                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff,   n_expert}, 0);
-                    }
-                } break;
-            case LLM_ARCH_OPENELM:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    // init output from the input tok embed
-                    model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        const int64_t n_head      =   hparams.n_head(i);
-                        const int64_t n_head_qkv  = 2*hparams.n_head_kv(i) + n_head;
-                        const int64_t n_ff        =   hparams.n_ff(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_head_qkv*n_embd_head_k}, 0);
-                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
-                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head*n_embd_head_k, n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_GPTNEOX:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, 0);
-
-                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.bo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_ARCTIC:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                    // if output is NULL, init from the input tok embed
-                    if (model.output == NULL) {
-                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_embd}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_embd, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
-                        layer.ffn_norm_exps = create_tensor(tn(LLM_TENSOR_FFN_NORM_EXPS, "weight", i), {n_embd}, 0);
-                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd,   n_ff, n_expert}, false);
-                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff, n_embd, n_expert}, 0);
-                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff, n_expert}, 0);
-                    }
-                } break;
-            case LLM_ARCH_DEEPSEEK:
-                {
-
-                    const int64_t n_ff_exp        = hparams.n_ff_exp;
-                    const int64_t n_expert_shared = hparams.n_expert_shared;
-
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        if (i < (int) hparams.n_layer_dense_lead) {
-                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                        } else {
-                            layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
-
-                            if (n_expert == 0) {
-                                throw std::runtime_error("n_expert must be > 0");
-                            }
-                            if (n_expert_used == 0) {
-                                throw std::runtime_error("n_expert_used must be > 0");
-                            }
-
-                            // MoE branch
-                            layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
-                            layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, 0);
-                            layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
-
-                            // Shared expert branch
-                            layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
-                            layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {        n_ff_exp * n_expert_shared, n_embd}, 0);
-                            layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
-                        }
-                    }
-                } break;
-            case LLM_ARCH_DEEPSEEK2:
-                {
-                    const bool is_lite = (hparams.n_layer == 27);
-
-                    const int64_t n_embd_head_qk_rope = hparams.n_rot;
-                    const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
-
-                    const int64_t q_lora_rank  = hparams.n_lora_q;
-                    const int64_t kv_lora_rank = hparams.n_lora_kv;
-
-                    const int64_t n_ff_exp        = hparams.n_ff_exp;
-                    const int64_t n_expert_shared = hparams.n_expert_shared;
-
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        if (!is_lite) {
-                            layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0);
-                        }
-
-                        layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0);
-
-                        if (!is_lite) {
-                            layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0);
-                            layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}, 0);
-                        } else {
-                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        }
-
-                        layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0);
-                        layer.wkv_b     = create_tensor(tn(LLM_TENSOR_ATTN_KV_B,     "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0);
-                        layer.wo        = create_tensor(tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {              n_head * (                      n_embd_head_v), n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        if (i < (int) hparams.n_layer_dense_lead) {
-                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                        } else {
-                            layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
-                            layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                            if (n_expert == 0) {
-                                throw std::runtime_error("n_expert must be > 0");
-                            }
-                            if (n_expert_used == 0) {
-                                throw std::runtime_error("n_expert_used must be > 0");
-                            }
-
-                            // MoE branch
-                            layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
-                            layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, 0);
-                            layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
-
-                            // Shared expert branch
-                            layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
-                            layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {        n_ff_exp * n_expert_shared, n_embd}, 0);
-                            layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
-                        }
-                    }
-                } break;
-            case LLM_ARCH_BITNET:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm     = create_tensor(tn(LLM_TENSOR_ATTN_NORM,     "weight", i), {n_embd}, 0);
-                        layer.attn_sub_norm = create_tensor(tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq       = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wq_scale = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.wk       = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wk_scale = create_tensor(tn(LLM_TENSOR_ATTN_K,   "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.wv       = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv_scale = create_tensor(tn(LLM_TENSOR_ATTN_V,   "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.wo       = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.wo_scale = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.ffn_norm     = create_tensor(tn(LLM_TENSOR_FFN_NORM,     "weight", i), {n_embd}, 0);
-                        layer.ffn_sub_norm = create_tensor(tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}, 0);
-
-                        layer.ffn_gate       = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_gate_scale = create_tensor(tn(LLM_TENSOR_FFN_GATE, "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_down       = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                        layer.ffn_down_scale = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_up         = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_up_scale   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    }
-                } break;
-            case LLM_ARCH_T5:
-                {
-                    const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts;
-
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output_norm     = create_tensor(tn(LLM_TENSOR_DEC_OUTPUT_NORM, "weight"), {n_embd}, 0);
-
-                    model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    // if output is NULL, init from the input tok embed
-                    if (model.output == NULL) {
-                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm_enc  = create_tensor(tn(LLM_TENSOR_ENC_ATTN_NORM,  "weight", i), {n_embd}, 0);
-                        layer.attn_rel_b_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wq_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wk_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wv_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
-                        layer.wo_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0);
-
-                        layer.ffn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd,   n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up_enc   = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-
-                        layer.attn_norm  = create_tensor(tn(LLM_TENSOR_DEC_ATTN_NORM,  "weight", i), {n_embd}, 0);
-                        layer.attn_rel_b = create_tensor(tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_DEC_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_DEC_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_DEC_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_DEC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0);
-
-                        layer.attn_norm_cross  = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_NORM,  "weight", i), {n_embd}, 0);
-                        // this tensor seems to be unused in HF transformers implementation
-                        layer.attn_rel_b_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wq_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wk_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wv_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
-                        layer.wo_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_DEC_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_DEC_FFN_GATE, "weight", i), {n_embd,   n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_DEC_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_DEC_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_T5ENCODER:
-                {
-                    const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts;
-
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    // if output is NULL, init from the input tok embed
-                    if (model.output == NULL) {
-                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm_enc  = create_tensor(tn(LLM_TENSOR_ENC_ATTN_NORM,  "weight", i), {n_embd}, 0);
-                        layer.attn_rel_b_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wq_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wk_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wv_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
-                        layer.wo_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0);
-
-                        layer.ffn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd,   n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up_enc   = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_JAIS:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM,   "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM,   "bias", i),   {n_embd}, 0);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, 0);
-
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_gate   = create_tensor(tn(LLM_TENSOR_FFN_GATE,   "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE,   "bias", i),   {n_ff}, 0);
-
-                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_CHATGLM:
-                {
-                    model.tok_embd   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,      "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, 0);
-
-                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff * 2}, 0);
-
-                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                    }
-                } break;
-            case LLM_ARCH_NEMOTRON:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
-                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        // optional bias tensors
-                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0);
-
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-
-                        // optional MLP bias
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    }
-                } break;
-            case LLM_ARCH_EXAONE:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
-
-                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM,   "weight", i), {n_embd}, 0);
-                        layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                        layer.ffn_gate   = create_tensor(tn(LLM_TENSOR_FFN_GATE,   "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN,   "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,     "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_RWKV6:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // Block 0, LN0
-                    model.tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
-                    model.tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
-                    model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
-
-                    const int time_mix_extra_dim = hparams.time_mix_extra_dim;
-                    const int time_decay_extra_dim = hparams.time_decay_extra_dim;
-                    const int head_size = hparams.wkv_head_size;
-                    const int attn_hidden_size = n_embd;
-                    const int ffn_size = hparams.n_ff_arr[0];
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.attn_norm_2   = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i),   {n_embd}, 0);
-
-                        layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5}, 0);
-                        layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0);
-
-                        layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0);
-                        layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, 0);
-                        layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0);
-                        layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, 0);
-                        layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, 0);
-                        layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, 0);
-
-                        layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, 0);
-                        layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0);
-                        layer.time_mix_decay_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}, 0);
-                        layer.time_mix_decay_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size}, 0);
-                        layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0);
-                        layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0);
-                        layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0);
-                        layer.time_mix_gate = create_tensor(tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd}, 0);
-
-                        layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, 0);
-                        layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, 0);
-                        layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0);
-
-                        layer.channel_mix_lerp_k = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0);
-                        layer.channel_mix_lerp_r = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, 0);
-
-                        layer.channel_mix_key = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_KEY, "weight", i), {n_embd, ffn_size}, 0);
-                        layer.channel_mix_value = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_VALUE, "weight", i), {ffn_size, n_embd}, 0);
-                        layer.channel_mix_receptance = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "weight", i), {n_embd, n_embd}, 0);
-                    }
-
-                } break;
-            case LLM_ARCH_CHAMELEON:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    // if output is NULL, init from the input tok embed
-                    if (model.output == NULL) {
-                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, 0);
-                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0);
-                        layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i),  {n_embd_head_k, n_head}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i),  {n_embd_head_k, n_head_kv}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_SOLAR:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    {
-                        model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                        model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.bskcn_tv = create_tensor(tn(LLM_TENSOR_BSKCN_TV, "weight", i), {2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_WAVTOKENIZER_DEC:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd_features, n_vocab}, 0);
-
-                    model.conv1d   = create_tensor(tn(LLM_TENSOR_CONV1D, "weight"), {7, hparams.n_embd_features, hparams.posnet.n_embd}, 0);
-                    model.conv1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias"),   {1, hparams.posnet.n_embd}, 0);
-
-                    // posnet
-                    {
-                        const int64_t n_embd = hparams.posnet.n_embd;
-
-                        for (uint32_t i = 0; i < hparams.posnet.n_layer; ++i) {
-                            auto & layer = model.layers[i].posnet;
-
-                            // posnet:
-                            //
-                            //  - resnet
-                            //  - resnet
-                            //  - attn
-                            //  - resnet
-                            //  - resnet
-                            //  - norm
-                            //
-                            switch (i) {
-                                case 0:
-                                case 1:
-                                case 3:
-                                case 4:
-                                    {
-                                        layer.norm1   = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "weight", i), {1, n_embd}, 0);
-                                        layer.norm1_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "bias",   i), {1, n_embd}, 0);
-
-                                        layer.conv1   = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "weight", i), {3, n_embd, n_embd}, 0);
-                                        layer.conv1_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "bias",   i), {1, n_embd}, 0);
-
-                                        layer.norm2   = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "weight", i), {1, n_embd}, 0);
-                                        layer.norm2_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "bias",   i), {1, n_embd}, 0);
-
-                                        layer.conv2   = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "weight", i), {3, n_embd, n_embd}, 0);
-                                        layer.conv2_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "bias",   i), {1, n_embd}, 0);
-                                    } break;
-                                case 2:
-                                    {
-                                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", i), {1, n_embd}, 0);
-                                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias",   i), {1, n_embd}, 0);
-
-                                        layer.attn_q      = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q,    "weight", i), {1, n_embd, n_embd}, 0);
-                                        layer.attn_q_b    = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q,    "bias",   i), {1, n_embd}, 0);
-
-                                        layer.attn_k      = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K,    "weight", i), {1, n_embd, n_embd}, 0);
-                                        layer.attn_k_b    = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K,    "bias",   i), {1, n_embd}, 0);
-
-                                        layer.attn_v      = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V,    "weight", i), {1, n_embd, n_embd}, 0);
-                                        layer.attn_v_b    = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V,    "bias",   i), {1, n_embd}, 0);
-
-                                        layer.attn_o      = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT,  "weight", i), {1, n_embd, n_embd}, 0);
-                                        layer.attn_o_b    = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT,  "bias",   i), {1, n_embd}, 0);
-                                    } break;
-                                case 5:
-                                    {
-                                        layer.norm   = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", i), {1, n_embd}, 0);
-                                        layer.norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias",   i), {1, n_embd}, 0);
-                                    } break;
-                                default: GGML_ABORT("unknown posnet layer");
-                            };
-                        }
-                    }
-
-                    GGML_ASSERT(hparams.posnet.n_embd == hparams.convnext.n_embd);
-
-                    model.tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {hparams.posnet.n_embd}, 0);
-                    model.tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {hparams.posnet.n_embd}, 0);
-
-                    // convnext
-                    {
-                        const int64_t n_embd = hparams.convnext.n_embd;
-
-                        for (uint32_t i = 0; i < hparams.convnext.n_layer; ++i) {
-                            auto & layer = model.layers[i].convnext;
-
-                            layer.dw     = create_tensor(tn(LLM_TENSOR_CONVNEXT_DW,    "weight", i), {7, 1, n_embd}, 0);
-                            layer.dw_b   = create_tensor(tn(LLM_TENSOR_CONVNEXT_DW,    "bias",   i), {1, n_embd}, 0);
-
-                            layer.norm   = create_tensor(tn(LLM_TENSOR_CONVNEXT_NORM,  "weight", i), {n_embd}, 0);
-                            layer.norm_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_NORM,  "bias",   i), {n_embd}, 0);
-
-                            layer.pw1    = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW1,   "weight", i), {n_embd, n_ff}, 0);
-                            layer.pw1_b  = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW1,   "bias",   i), {n_ff}, 0);
-
-                            layer.pw2    = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW2,   "weight", i), {n_ff, n_embd}, 0);
-                            layer.pw2_b  = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW2,   "bias",   i), {n_embd}, 0);
-
-                            layer.gamma  = create_tensor(tn(LLM_TENSOR_CONVNEXT_GAMMA, "weight", i), {n_embd}, 0);
-                        }
-
-                        // output
-                        model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                        model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-                    }
-
-                    model.output   = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0);
-                    model.output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"),   {n_embd}, 0);
-                } break;
-            default:
-                throw std::runtime_error("unknown architecture");
-        }
-
-        if (n_moved_tensors > 0) {
-            LLAMA_LOG_DEBUG("%s: tensor '%s' (%s) (and %d others) cannot be used with preferred buffer type %s, using %s instead\n",
-                __func__, first_moved_tensor->name, ggml_type_name(first_moved_tensor->type), n_moved_tensors - 1,
-                ggml_backend_buft_name(first_moved_from_buft), ggml_backend_buft_name(first_moved_to_buft));
-        }
-    }
-
-    ml.done_getting_tensors();
-
-    ml.init_mappings(true, use_mlock ? &model.mlock_mmaps : nullptr);
-    model.mappings.reserve(ml.mappings.size());
-
-    // create the backend buffers
-    std::vector> ctx_bufs;
-    ctx_bufs.reserve(ctx_map.size());
-
-    // Ensure we have enough capacity for the maximum backend buffer we will potentially create
-    const size_t n_max_backend_buffer = ctx_map.size() * ml.files.size();
-    model.bufs.reserve(n_max_backend_buffer);
-
-    for (auto & it : ctx_map) {
-        ggml_backend_buffer_type_t buft = it.first;
-        ggml_context * ctx              = it.second;
-
-        // skip contexts without tensors
-        if (ggml_get_first_tensor(ctx) == nullptr) {
-            continue;
-        }
-
-        llama_buf_map bufs;
-        bufs.reserve(n_max_backend_buffer);
-
-        // check if it is possible to use buffer_from_host_ptr with this buffer type
-        ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft);
-        if (!dev) {
-            // FIXME: workaround for CPU backend buft having a NULL device
-            dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
-        }
-        ggml_backend_dev_props props;
-        ggml_backend_dev_get_props(dev, &props);
-        bool buffer_from_host_ptr_supported = props.caps.buffer_from_host_ptr;
-        bool is_default_buft = buft == ggml_backend_dev_buffer_type(dev);
-
-        if (ml.use_mmap && use_mmap_buffer && buffer_from_host_ptr_supported && is_default_buft) {
-            for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
-                // only the mmap region containing the tensors in the model is mapped to the backend buffer
-                // this is important for metal with apple silicon: if the entire model could be mapped to a metal buffer, then we could just use metal for all layers
-                // this allows using partial offloading when the model size exceeds the metal buffer size, but not the RAM size
-                void * addr = nullptr;
-                size_t first, last; // NOLINT
-                ml.get_mapping_range(&first, &last, &addr, idx, ctx);
-                if (first >= last) {
-                    continue;
-                }
-                const size_t max_size = ggml_get_max_tensor_size(ctx);
-                ggml_backend_buffer_t buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size);
-                if (buf == nullptr) {
-                    throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
-                }
-                model.bufs.emplace_back(buf);
-                bufs.emplace(idx, buf);
-            }
-        }
-        else {
-            ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
-            if (buf == nullptr) {
-                throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
-            }
-            model.bufs.emplace_back(buf);
-            if (use_mlock && ggml_backend_buffer_is_host(buf)) {
-                model.mlock_bufs.emplace_back(new llama_mlock);
-                auto & mlock_buf = model.mlock_bufs.back();
-                mlock_buf->init   (ggml_backend_buffer_get_base(buf));
-                mlock_buf->grow_to(ggml_backend_buffer_get_size(buf));
-            }
-            for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
-                bufs.emplace(idx, buf);
-            }
-        }
-
-        if (bufs.empty()) {
-            throw std::runtime_error("failed to allocate buffer");
-        }
-
-        for (auto & buf : bufs) {
-            // indicate that this buffer contains weights
-            // this is used by ggml_backend_sched to improve op scheduling: ops that use a weight are preferably scheduled to the backend that contains the weight
-            ggml_backend_buffer_set_usage(buf.second, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
-        }
-
-        ctx_bufs.emplace_back(ctx, bufs);
-    }
-
-    if (llama_supports_gpu_offload()) {
-        const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
-
-        LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_gpu);
-        if (n_gpu_layers > (int) hparams.n_layer) {
-            LLAMA_LOG_INFO("%s: offloading output layer to GPU\n", __func__);
-        }
-
-        const int max_backend_supported_layers = hparams.n_layer + 1;
-        const int max_offloadable_layers       = hparams.n_layer + 1;
-
-        LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers);
-    }
-
-    // print memory requirements per buffer type
-    for (auto & buf : model.bufs) {
-        LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0);
-    }
-
-    // populate tensors_by_name
-    for (auto & ctx : model.ctxs) {
-        for (auto * cur = ggml_get_first_tensor(ctx.get()); cur != NULL; cur = ggml_get_next_tensor(ctx.get(), cur)) {
-            model.tensors_by_name.emplace_back(ggml_get_name(cur), cur);
-        }
-    }
-
-    // load tensor data
-    for (auto & it : ctx_bufs) {
-        ggml_context * ctx = it.first;
-        auto & bufs = it.second;
-        if (!ml.load_all_data(ctx, bufs, use_mlock ? &model.mlock_mmaps : NULL, progress_callback, progress_callback_user_data)) {
-            return false;
-        }
-    }
-
-    if (use_mmap_buffer) {
-        for (auto & mapping : ml.mappings) {
-            model.mappings.emplace_back(std::move(mapping));
-        }
-    }
-
-    return true;
-}
-
 // Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
-static int llama_model_load(const std::string & fname, llama_model & model, llama_model_params & params) {
-    model.t_start_us = ggml_time_us();
+static int llama_model_load(const std::string & fname, std::vector & splits, llama_model & model, llama_model_params & params) {
+    // loading time will be recalculated after the first eval, so
+    // we take page faults deferred by mmap() into consideration
+    model.t_load_us = 0;
+    time_meas tm(model.t_load_us);
+
+    model.t_start_us = tm.t_start_us;
 
     try {
-        llama_model_loader ml(fname, params.use_mmap, params.check_tensors, params.kv_overrides);
+        llama_model_loader ml(fname, splits, params.use_mmap, params.check_tensors, params.kv_overrides);
+
+        ml.print_info();
 
         model.hparams.vocab_only = params.vocab_only;
 
         try {
-            llm_load_arch(ml, model);
+            model.load_arch(ml);
         } catch(const std::exception & e) {
             throw std::runtime_error("error loading model architecture: " + std::string(e.what()));
         }
         try {
-            llm_load_hparams(ml, model);
+            model.load_hparams(ml);
         } catch(const std::exception & e) {
             throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what()));
         }
         try {
-            llm_load_vocab(ml, model);
+            model.load_vocab(ml);
         } catch(const std::exception & e) {
             throw std::runtime_error("error loading model vocabulary: " + std::string(e.what()));
         }
 
-        llm_load_stats(ml, model);
-        llm_load_print_meta(ml, model);
-
-        if (model.vocab.type != LLAMA_VOCAB_TYPE_NONE &&
-            model.hparams.n_vocab != model.vocab.id_to_token.size()) {
-            LLAMA_LOG_WARN("%s: vocab mismatch %u !- %zu ...\n", __func__, model.hparams.n_vocab, model.vocab.id_to_token.size());
-        }
+        model.load_stats(ml);
+        model.print_info();
 
         if (params.vocab_only) {
             LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__);
             return 0;
         }
 
-        if (!llm_load_tensors(
-            ml, model, params.n_gpu_layers, params.split_mode,  params.main_gpu, params.tensor_split, params.use_mlock,
-            params.progress_callback, params.progress_callback_user_data
-        )) {
+        if (!model.load_tensors(ml)) {
             return -2;
         }
     } catch (const std::exception & err) {
@@ -2579,10 +78,6 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
         return -1;
     }
 
-    // loading time will be recalculate after the first eval, so
-    // we take page faults deferred by mmap() into consideration
-    model.t_load_us = ggml_time_us() - model.t_start_us;
-
     return 0;
 }
 
@@ -2615,21 +110,36 @@ static struct ggml_tensor * llm_build_inp_embd(
         struct ggml_context * ctx,
        struct llama_context & lctx,
         const llama_hparams & hparams,
-         const llama_ubatch & batch,
+         const llama_ubatch & ubatch,
          struct ggml_tensor * tok_embd,
          const llm_build_cb & cb) {
     const int64_t n_embd = hparams.n_embd;
 
     struct ggml_tensor * inpL;
 
-    if (batch.token) {
-        lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, batch.n_tokens);
+    if (ubatch.token) {
+        lctx.inp_tokens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ubatch.n_tokens);
         cb(lctx.inp_tokens, "inp_tokens", -1);
         ggml_set_input(lctx.inp_tokens);
 
         inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
+
+        // apply lora for embedding tokens if needed
+        for (auto & it : lctx.lora) {
+            struct llama_adapter_lora_weight * lw = it.first->get_weight(tok_embd);
+            if (lw == nullptr) {
+                continue;
+            }
+            const float adapter_scale = it.second;
+            const float scale = lw->get_scale(it.first->alpha, adapter_scale);
+            struct ggml_tensor * inpL_delta = ggml_scale(ctx, ggml_mul_mat(
+                ctx, lw->b, // non-transposed lora_b
+                ggml_get_rows(ctx, lw->a, lctx.inp_tokens)
+            ), scale);
+            inpL = ggml_add(ctx, inpL, inpL_delta);
+        }
     } else {
-        lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_tokens);
+        lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
         inpL = lctx.inp_embd;
         ggml_set_input(lctx.inp_embd);
     }
@@ -2710,17 +220,16 @@ static struct ggml_tensor * llm_build_lora_mm(
           struct ggml_tensor * w,
           struct ggml_tensor * cur) {
     struct ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
-    for (auto & it : lctx.lora_adapters) {
-        struct llama_lora_weight * lora = it.first->get_weight(w);
-        if (lora == nullptr) {
+    for (auto & it : lctx.lora) {
+        struct llama_adapter_lora_weight * lw = it.first->get_weight(w);
+        if (lw == nullptr) {
             continue;
         }
-        const float alpha = it.first->alpha;
-        const float rank  = (float) lora->b->ne[0];
-        const float scale = alpha ? it.second * alpha / rank : it.second;
+        const float adapter_scale = it.second;
+        const float scale = lw->get_scale(it.first->alpha, adapter_scale);
         struct ggml_tensor * ab_cur = ggml_mul_mat(
-            ctx0, lora->b,
-            ggml_mul_mat(ctx0, lora->a, cur)
+            ctx0, lw->b,
+            ggml_mul_mat(ctx0, lw->a, cur)
         );
         ab_cur = ggml_scale(ctx0, ab_cur, scale);
         res = ggml_add(ctx0, res, ab_cur);
@@ -2736,17 +245,17 @@ static struct ggml_tensor * llm_build_lora_mm_id(
           struct ggml_tensor * cur, // struct ggml_tensor * b
           struct ggml_tensor * ids) {
     struct ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
-    for (auto & it : lctx.lora_adapters) {
-        struct llama_lora_weight * lora = it.first->get_weight(w);
-        if (lora == nullptr) {
+    for (auto & it : lctx.lora) {
+        struct llama_adapter_lora_weight * lw = it.first->get_weight(w);
+        if (lw == nullptr) {
             continue;
         }
         const float alpha = it.first->alpha;
-        const float rank  = (float) lora->b->ne[0];
+        const float rank  = (float) lw->b->ne[0];
         const float scale = alpha ? it.second * alpha / rank : it.second;
         struct ggml_tensor * ab_cur = ggml_mul_mat_id(
-            ctx0, lora->b,
-            ggml_mul_mat_id(ctx0, lora->a, cur, ids),
+            ctx0, lw->b,
+            ggml_mul_mat_id(ctx0, lw->a, cur, ids),
             ids
         );
         ab_cur = ggml_scale(ctx0, ab_cur, scale);
@@ -3239,7 +748,7 @@ static struct ggml_tensor * llm_build_copy_mask_state(
 static struct ggml_tensor * llm_build_mamba(
         struct ggml_context * ctx,
        struct llama_context & lctx,
-         const llama_ubatch & batch,
+         const llama_ubatch & ubatch,
          struct ggml_cgraph * graph,
          struct ggml_tensor * cur,
          struct ggml_tensor * state_copy,
@@ -3255,17 +764,17 @@ static struct ggml_tensor * llm_build_mamba(
     const int64_t d_inner = hparams.ssm_d_inner;
     const int64_t d_state = hparams.ssm_d_state;
     const int64_t dt_rank = hparams.ssm_dt_rank;
-    const int64_t n_seqs  = batch.n_seqs;
+    const int64_t n_seqs  = ubatch.n_seqs;
     // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers)
     const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms;
     // Use the same RMS norm as the final layer norm
     const float norm_rms_eps = hparams.f_norm_rms_eps;
 
-    const int64_t n_seq_tokens = batch.n_seq_tokens;
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
 
     GGML_ASSERT(n_seqs != 0);
-    GGML_ASSERT(batch.equal_seqs);
-    GGML_ASSERT(batch.n_tokens == n_seq_tokens * n_seqs);
+    GGML_ASSERT(ubatch.equal_seqs);
+    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
 
     struct ggml_tensor * conv_states_all = kv.k_l[il];
     struct ggml_tensor * ssm_states_all  = kv.v_l[il];
@@ -3377,16 +886,20 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
         const struct llama_layer * layer,
         struct ggml_tensor * cur,
         struct ggml_tensor * x_prev,
-        struct ggml_tensor ** wkv_state) {
+        struct ggml_tensor ** wkv_state,
+        size_t wkv_head_size,
+        size_t head_count_kv) {
     size_t n_embd       = cur->ne[0];
     size_t n_seq_tokens = cur->ne[1];
     size_t n_seqs       = cur->ne[2];
 
-    size_t head_size  = layer->time_mix_first->ne[0];
-    size_t head_count = layer->time_mix_first->ne[1];
+    size_t head_size  = wkv_head_size;
+    size_t head_count = n_embd / head_size;
 
     size_t n_tokens = n_seqs * n_seq_tokens;
 
+    bool is_qrwkv = layer->time_mix_first == nullptr;
+
     struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
 
     sx  = ggml_reshape_2d(ctx, sx,  n_embd, n_tokens);
@@ -3415,69 +928,64 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
         xxx
     );
 
-    struct ggml_tensor *mw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
-    struct ggml_tensor *mk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
-    struct ggml_tensor *mv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
-    struct ggml_tensor *mr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
-    struct ggml_tensor *mg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
+    struct ggml_tensor *xw, *xk, *xv, *xr, *xg;
+    if (layer->time_mix_lerp_fused) {
+        // fusing these weights makes some performance improvement
+        sx  = ggml_reshape_3d(ctx, sx,  n_embd, 1, n_tokens);
+        cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens);
+        xxx = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xxx, layer->time_mix_lerp_fused), sx), cur);
+        xw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
+        xk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
+        xv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
+        xr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
+        xg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
+    } else {
+        // for backward compatibility
+        xw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
+        xk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
+        xv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
+        xr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
+        xg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
 
-    struct ggml_tensor * xw = ggml_add(
-        ctx,
-        ggml_mul(
-            ctx,
-            ggml_add(ctx, mw, layer->time_mix_lerp_w),
-            sx
-        ),
-        cur
-    );
+        xw = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xw, layer->time_mix_lerp_w), sx), cur);
+        xk = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xk, layer->time_mix_lerp_k), sx), cur);
+        xv = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xv, layer->time_mix_lerp_v), sx), cur);
+        xr = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xr, layer->time_mix_lerp_r), sx), cur);
+        xg = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xg, layer->time_mix_lerp_g), sx), cur);
+    }
 
-    struct ggml_tensor * xk = ggml_add(
-        ctx,
-        ggml_mul(
-            ctx,
-            ggml_add(ctx, mk, layer->time_mix_lerp_k),
-            sx
-        ),
-        cur
-    );
+    struct ggml_tensor * r = llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr);
+    struct ggml_tensor * k = llm_build_lora_mm(lctx, ctx, layer->time_mix_key,        xk);
+    struct ggml_tensor * v = llm_build_lora_mm(lctx, ctx, layer->time_mix_value,      xv);
+    if (layer->time_mix_receptance_b) {
+        r = ggml_add(ctx, r, layer->time_mix_receptance_b);
+    }
+    if (layer->time_mix_key_b) {
+        k = ggml_add(ctx, k, layer->time_mix_key_b);
+    }
+    if (layer->time_mix_value_b) {
+        v = ggml_add(ctx, v, layer->time_mix_value_b);
+    }
 
-    struct ggml_tensor * xv = ggml_add(
-        ctx,
-        ggml_mul(
-            ctx,
-            ggml_add(ctx, mv, layer->time_mix_lerp_v),
-            sx
-        ),
-        cur
-    );
+    struct ggml_tensor * g = llm_build_lora_mm(lctx, ctx, layer->time_mix_gate, xg);
+    if (is_qrwkv) {
+        g = ggml_sigmoid(ctx, g);
+    } else {
+        g = ggml_silu(ctx, g);
+    }
 
-    struct ggml_tensor * xr = ggml_add(
-        ctx,
-        ggml_mul(
-            ctx,
-            ggml_add(ctx, mr, layer->time_mix_lerp_r),
-            sx
-        ),
-        cur
-    );
+    if (head_count_kv != head_count) {
+        GGML_ASSERT(head_count % head_count_kv == 0);
+        k = ggml_reshape_4d(ctx, k, head_size, 1, head_count_kv, n_tokens);
+        v = ggml_reshape_4d(ctx, v, head_size, 1, head_count_kv, n_tokens);
+        struct ggml_tensor * tmp = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_size, head_count / head_count_kv, head_count_kv, n_tokens);
+        k = ggml_repeat(ctx, k, tmp);
+        v = ggml_repeat(ctx, v, tmp);
+    }
 
-    struct ggml_tensor * xg = ggml_add(
-        ctx,
-        ggml_mul(
-            ctx,
-            ggml_add(ctx, mg, layer->time_mix_lerp_g),
-            sx
-        ),
-        cur
-    );
-
-    struct ggml_tensor * r = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr), head_size, 1,         head_count, n_tokens);
-    struct ggml_tensor * k = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_key,        xk), 1,         head_size, head_count, n_tokens);
-    struct ggml_tensor * v = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_value,      xv), head_size, 1,         head_count, n_tokens);
-    struct ggml_tensor * g = ggml_silu(
-        ctx,
-        llm_build_lora_mm(lctx, ctx, layer->time_mix_gate, xg)
-    );
+    k = ggml_reshape_3d(ctx, k, head_size, head_count, n_tokens);
+    v = ggml_reshape_3d(ctx, v, head_size, head_count, n_tokens);
+    r = ggml_reshape_3d(ctx, r, head_size, head_count, n_tokens);
 
     struct ggml_tensor * w = ggml_mul_mat(
         ctx,
@@ -3488,25 +996,35 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
         )
     );
 
-    w = ggml_add(ctx, w, ggml_reshape_1d(ctx, layer->time_mix_decay, n_embd));
+    w = ggml_add(ctx, w, layer->time_mix_decay);
     w = ggml_exp(ctx, ggml_neg(ctx, ggml_exp(ctx, w)));
-    w = ggml_reshape_4d(ctx, w, 1, head_size, head_count, n_tokens);
+    w = ggml_reshape_3d(ctx, w, head_size, head_count, n_tokens);
 
-    k = ggml_transpose(ctx, k);
-    v = ggml_transpose(ctx, v);
-    r = ggml_transpose(ctx, r);
+    if (is_qrwkv) {
+        // k = k * (1 - w)
+        k = ggml_sub(ctx, k, ggml_mul(ctx, k, w));
+    }
 
-    struct ggml_tensor * wkv_output = ggml_rwkv_wkv6(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
+    struct ggml_tensor * wkv_output;
+    if (!layer->time_mix_first) {
+        wkv_output = ggml_gated_linear_attn(ctx, k, v, r, w, *wkv_state, pow(head_size, -0.5f));
+    } else {
+        wkv_output = ggml_rwkv_wkv6(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
+    }
     cur = ggml_view_1d(ctx, wkv_output, n_embd * n_tokens, 0);
     *wkv_state = ggml_view_1d(ctx, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));
 
-    // group norm with head_count groups
-    cur = ggml_reshape_3d(ctx, cur, n_embd / head_count, head_count, n_tokens);
-    cur = ggml_norm(ctx, cur, 64e-5f);
+    if (!is_qrwkv) {
+        // group norm with head_count groups
+        cur = ggml_reshape_3d(ctx, cur, n_embd / head_count, head_count, n_tokens);
+        cur = ggml_norm(ctx, cur, 64e-5f);
 
-    // Convert back to regular vectors.
-    cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
-    cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);
+        // Convert back to regular vectors.
+        cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
+        cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);
+    } else {
+        cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
+    }
 
     cur = ggml_mul(ctx, cur, g);
     cur = llm_build_lora_mm(lctx, ctx, layer->time_mix_output, cur);
@@ -3670,7 +1188,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_k_shift() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         GGML_ASSERT(kv_self.size == n_ctx);
 
@@ -3720,7 +1238,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_defrag(const std::vector & moves) {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         for (const auto & move : moves) {
             for (int il = 0; il < n_layer; ++il) {
@@ -3965,7 +1483,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_llama() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -4059,6 +1577,7 @@ struct llm_build_context {
 
             // feed-forward network
             if (model.layers[il].ffn_gate_inp == nullptr) {
+
                 cur = llm_build_norm(ctx0, ffn_inp, hparams,
                         model.layers[il].ffn_norm, NULL,
                         LLM_NORM_RMS, cb, il);
@@ -4129,8 +1648,8 @@ struct llm_build_context {
         return gf;
     }
 
-        struct ggml_cgraph * build_mllama() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+    struct ggml_cgraph * build_mllama() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -4364,7 +1883,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_deci() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -4525,7 +2044,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_baichuan() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -4537,7 +2056,7 @@ struct llm_build_context {
         inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = model.type == MODEL_7B ? build_inp_pos() : nullptr;
+        struct ggml_tensor * inp_pos = model.type == LLM_TYPE_7B ? build_inp_pos() : nullptr;
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
@@ -4562,7 +2081,7 @@ struct llm_build_context {
                 cb(Vcur, "Vcur", il);
 
                 switch (model.type) {
-                    case MODEL_7B:
+                    case LLM_TYPE_7B:
                         Qcur = ggml_rope_ext(
                             ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
                             n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -4574,7 +2093,7 @@ struct llm_build_context {
                             ext_factor, attn_factor, beta_fast, beta_slow
                         );
                         break;
-                    case MODEL_13B:
+                    case LLM_TYPE_13B:
                         Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd/n_head, n_head, n_tokens);
                         Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd/n_head, n_head, n_tokens);
                         break;
@@ -4640,7 +2159,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_xverse() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -4743,7 +2262,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_falcon() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -4863,7 +2382,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_grok() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -5022,7 +2541,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_dbrx() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -5150,7 +2669,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_starcoder() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -5254,7 +2773,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_refact() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -5348,7 +2867,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_bert() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -5542,7 +3061,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_bloom() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -5643,7 +3162,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_mpt() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -5933,7 +3452,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_qwen() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -6045,7 +3564,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_qwen2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -6157,7 +3676,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_qwen2vl() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
         GGML_ASSERT(n_embd_head == hparams.n_rot);
@@ -6275,7 +3794,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_qwen2moe() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -6423,7 +3942,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_phi2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -6544,7 +4063,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_phi3() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
@@ -6577,7 +4096,7 @@ struct llm_build_context {
 
                 struct ggml_tensor* attn_norm_output = llm_build_norm(ctx0, inpL, hparams,
                     model.layers[il].attn_norm,
-                    NULL,
+                    model.layers[il].attn_norm_b,
                     LLM_NORM_RMS, cb, il);
                 cb(attn_norm_output, "attn_norm", il);
 
@@ -6592,8 +4111,7 @@ struct llm_build_context {
                     Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0 * sizeof(float) * (n_embd)));
                     Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd)));
                     Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)));
-                }
-                else {
+                } else {
                     Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq);
                     Kcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk);
                     Vcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv);
@@ -6637,14 +4155,12 @@ struct llm_build_context {
             residual = cur;
 
             cur = llm_build_norm(ctx0, cur, hparams,
-                model.layers[il].ffn_norm, NULL,
+                model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
                 LLM_NORM_RMS, cb, il);
             cb(cur, "ffn_norm", il);
 
-            // FF
-            // special-case: the up and gate tensors are merged into a single tensor
-            // TOOD: support into llm_build_ffn
-            {
+            // feed-forward network
+            if (model.layers[il].ffn_gate_inp == nullptr) {
                 cur = llm_build_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_up,   NULL, NULL,
                         NULL,                      NULL, NULL,
@@ -6652,6 +4168,20 @@ struct llm_build_context {
                         NULL,
                         LLM_FFN_SWIGLU, LLM_FFN_SEQ, cb, il);
                 cb(cur, "ffn_out", il);
+            } else {
+                // MoE branch
+                cur = llm_build_moe_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_gate_inp,
+                        model.layers[il].ffn_up_exps,
+                        model.layers[il].ffn_gate_exps,
+                        model.layers[il].ffn_down_exps,
+                        nullptr,
+                        n_expert, n_expert_used,
+                        LLM_FFN_SILU, true,
+                        false, 0.0,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        cb, il);
+                cb(cur, "ffn_moe_out", il);
             }
 
             cur = ggml_add(ctx0, residual, cur);
@@ -6664,11 +4194,16 @@ struct llm_build_context {
 
         cur = llm_build_norm(ctx0, inpL, hparams,
             model.output_norm,
-            NULL,
+            model.output_norm_b,
             LLM_NORM_RMS, cb, -1);
         cb(cur, "result_norm", -1);
 
         cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+
+        if (model.output_b != nullptr) {
+            cb(cur, "result_output_no_bias", -1);
+            cur = ggml_add(ctx0, cur, model.output_b);
+        }
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -6782,7 +4317,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_gpt2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -6887,7 +4422,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_codeshell() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -6998,7 +4533,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_orion() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -7116,7 +4651,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_internlm2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -7234,7 +4769,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_minicpm3() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         //TODO: if the model varies, these parameters need to be read from the model
         const int64_t n_embd_base = 256;
@@ -7318,7 +4853,8 @@ struct llm_build_context {
                         ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
                 cb(k_pe, "k_pe", il);
 
-                kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
+                // TODO: the CUDA backend used to not support non-cont. (RMS) norm, investigate removing ggml_cont
+                kv_compressed = ggml_cont(ctx0, kv_compressed);
                 kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
                         model.layers[il].attn_kv_a_norm, NULL,
                         LLM_NORM_RMS, cb, il);
@@ -7350,7 +4886,7 @@ struct llm_build_context {
                     0);
                 cb(v_states, "v_states", il);
 
-                q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
+                q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
                 q_pe = ggml_rope_ext(
                     ctx0, q_pe, inp_pos, rope_factors,
                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -7359,7 +4895,7 @@ struct llm_build_context {
                 cb(q_pe, "q_pe", il);
 
                 // shared RoPE key
-                k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
+                k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
                 k_pe = ggml_rope_ext(
                     ctx0, k_pe, inp_pos, rope_factors,
                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -7443,7 +4979,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_gemma() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head_k = hparams.n_embd_head_k;
 
@@ -7551,7 +5087,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_gemma2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head_k = hparams.n_embd_head_k;
 
@@ -7601,9 +5137,9 @@ struct llm_build_context {
 
                 // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
                 switch (model.type) {
-                    case llm_type::MODEL_2B:
-                    case llm_type::MODEL_9B:  Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));   break;
-                    case llm_type::MODEL_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head))); break;
+                    case LLM_TYPE_2B:
+                    case LLM_TYPE_9B:  Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));   break;
+                    case LLM_TYPE_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head))); break;
                     default: GGML_ABORT("fatal error");
                 };
                 cb(Qcur, "Qcur_scaled", il);
@@ -7687,7 +5223,7 @@ struct llm_build_context {
 
 
     struct ggml_cgraph * build_starcoder2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -7806,7 +5342,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_mamba() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
@@ -7861,7 +5397,7 @@ struct llm_build_context {
 
     struct ggml_cgraph * build_command_r() {
 
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -8009,7 +5545,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_cohere2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -8146,7 +5682,7 @@ struct llm_build_context {
     //   * removed bias
     //   * removed MoE
     struct ggml_cgraph * build_olmo() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -8270,7 +5806,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_olmo2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -8398,7 +5934,7 @@ struct llm_build_context {
     //   * removed bias
     //   * added q, k norm
     struct ggml_cgraph * build_olmoe() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -8524,7 +6060,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_openelm() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -8649,7 +6185,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_gptneox() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -8791,7 +6327,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_arctic() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -8925,7 +6461,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_deepseek() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -9082,7 +6618,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_deepseek2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -9172,7 +6708,8 @@ struct llm_build_context {
                         ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
                 cb(k_pe, "k_pe", il);
 
-                kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
+                // TODO: the CUDA backend used to not support non-cont. (RMS) norm, investigate removing ggml_cont
+                kv_compressed = ggml_cont(ctx0, kv_compressed);
                 kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
                         model.layers[il].attn_kv_a_norm, NULL,
                         LLM_NORM_RMS, cb, il);
@@ -9204,7 +6741,7 @@ struct llm_build_context {
                     0);
                 cb(v_states, "v_states", il);
 
-                q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
+                q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
                 q_pe = ggml_rope_ext(
                     ctx0, q_pe, inp_pos, nullptr,
                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -9213,7 +6750,7 @@ struct llm_build_context {
                 cb(q_pe, "q_pe", il);
 
                 // shared RoPE key
-                k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
+                k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend used to not support non-cont. RoPE, investigate removing this
                 k_pe = ggml_rope_ext(
                     ctx0, k_pe, inp_pos, nullptr,
                     n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -9312,7 +6849,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_bitnet() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -9463,7 +7000,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_t5_enc() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -9595,7 +7132,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_t5_dec() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -9800,7 +7337,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_jais() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -9892,7 +7429,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_chatglm() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -9923,17 +7460,30 @@ struct llm_build_context {
                 struct ggml_tensor * Qcur = nullptr;
                 struct ggml_tensor * Kcur = nullptr;
                 struct ggml_tensor * Vcur = nullptr;
-
-                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
-                cb(cur, "wqkv", il);
-
-                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
-                cb(cur, "bqkv", il);
-
-                Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
-
+                if (model.layers[il].wqkv == nullptr) {
+                    Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
+                    if (model.layers[il].bq) {
+                        Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    }
+                    Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
+                    if (model.layers[il].bk) {
+                        Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    }
+                    Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
+                    if (model.layers[il].bv) {
+                        Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    }
+                } else {
+                    cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
+                    cb(cur, "wqkv", il);
+                    if (model.layers[il].bqkv) {
+                        cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                        cb(cur, "bqkv", il);
+                    }
+                    Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                    Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                    Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+                }
                 cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
@@ -10006,7 +7556,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_nemotron() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -10127,7 +7677,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_exaone() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -10254,7 +7804,7 @@ struct llm_build_context {
     }
 
     ggml_cgraph * build_rwkv6() {
-        ggml_cgraph *gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // Token shift state dimensions should be 2 * n_emb
         GGML_ASSERT(n_embd == hparams.n_embd_k_s() / 2);
@@ -10299,7 +7849,7 @@ struct llm_build_context {
                 1
             );
 
-            cur = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states));
+            cur = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states, hparams.wkv_head_size, n_embd / hparams.wkv_head_size));
             ggml_build_forward_expand(gf, cur);
             ggml_build_forward_expand(
                 gf,
@@ -10366,6 +7916,114 @@ struct llm_build_context {
         return gf;
     }
 
+    // ref: https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1/blob/main/modeling_rwkv6qwen2.py
+    ggml_cgraph * build_rwkv6qwen2() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
+
+        GGML_ASSERT(n_embd == hparams.n_embd_k_s());
+
+        const int64_t n_seqs = ubatch.n_seqs;
+        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+        const int64_t n_tokens = ubatch.n_tokens;
+        GGML_ASSERT(n_seqs != 0);
+        GGML_ASSERT(ubatch.equal_seqs);
+        GGML_ASSERT(n_tokens == n_seq_tokens * n_seqs);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+        struct ggml_tensor * state_copy = build_inp_s_copy();
+        struct ggml_tensor * state_mask = build_inp_s_mask();
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
+
+        for (int il = 0; il < n_layer; ++il) {
+            const llama_layer * layer = &model.layers[il];
+
+            // (ab)using the KV cache to store the states
+            struct ggml_tensor * token_shift = llm_build_copy_mask_state(ctx0,
+                    gf, kv_self.k_l[il], state_copy, state_mask,
+                    hparams.n_embd_k_s(), kv_self.size, kv_head, n_kv, n_seqs);
+            struct ggml_tensor * wkv_states = llm_build_copy_mask_state(ctx0,
+                    gf, kv_self.v_l[il], state_copy, state_mask,
+                    hparams.n_embd_v_s(), kv_self.size, kv_head, n_kv, n_seqs);
+
+            cur = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
+            token_shift = ggml_reshape_3d(ctx0, token_shift, n_embd, 1, n_seqs);
+
+            struct ggml_tensor * x_norm_att = llm_build_norm(ctx0, cur, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, cb, il);
+            struct ggml_tensor * x_prev = ggml_concat(
+                ctx0,
+                token_shift,
+                ggml_view_3d(ctx0, x_norm_att, n_embd, n_seq_tokens - 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], 0),
+                1
+            );
+
+            struct ggml_tensor * last_norm_att = ggml_view_3d(ctx0, x_norm_att, n_embd, 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(x_norm_att));
+            ggml_build_forward_expand(
+                gf,
+                ggml_cpy(
+                    ctx0,
+                    ggml_view_1d(ctx0, last_norm_att, n_embd * n_seqs, 0),
+                    ggml_view_1d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self.k_l[il]))
+                )
+            );
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states, hparams.wkv_head_size, hparams.n_head_kv()));
+            ggml_build_forward_expand(gf, ffn_inp);
+            ggml_build_forward_expand(
+                gf,
+                ggml_cpy(
+                    ctx0,
+                    wkv_states,
+                    ggml_view_1d(
+                        ctx0,
+                        kv_self.v_l[il],
+                        hparams.n_embd_v_s() * n_seqs,
+                        hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
+                    )
+                )
+            );
+
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = llm_build_ffn(ctx0, lctx, cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+        struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
+        cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+
+        cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     // ref: https://github.com/facebookresearch/chameleon
     // based on the original build_llama() function, changes:
     //   * qk-norm
@@ -10373,7 +8031,7 @@ struct llm_build_context {
     //   * removed bias
     //   * removed MoE
     struct ggml_cgraph * build_chameleon() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -10538,14 +8196,12 @@ struct llm_build_context {
         cb(img_logits, "img_logits", -1);
         cur = ggml_set_1d(ctx0, cur, img_logits, ggml_element_size(cur) * img_token_start_idx);
         cb(cur, "result_output", -1);
-
         ggml_build_forward_expand(gf, cur);
-
         return gf;
-    }
+   }
 
-    ggml_cgraph * build_solar() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+   ggml_cgraph * build_solar() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -10681,23 +8337,19 @@ struct llm_build_context {
         }
 
         cur = inpL;
-
         cur = llm_build_norm(ctx0, cur, hparams,
                 model.output_norm, NULL,
                 LLM_NORM_RMS, cb, -1);
         cb(cur, "result_norm", -1);
-
         // lm_head
         cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
-
         ggml_build_forward_expand(gf, cur);
-
         return gf;
     }
 
     struct ggml_cgraph * build_wavtokenizer_dec() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
@@ -10906,12 +8558,12 @@ static struct ggml_cgraph * llama_build_graph(
 
         // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
         // FIXME: fix in ggml_backend_sched
-        const bool full_offload = lctx.model.n_gpu_layers > (int)lctx.model.hparams.n_layer;
+        const bool full_offload = lctx.model.params.n_gpu_layers > (int) lctx.model.hparams.n_layer;
         if (ubatch.n_tokens < 32 || full_offload) {
             if (il != -1 && strcmp(name, "norm") == 0) {
-                const auto & dev_layer = lctx.model.dev_layer.at(il);
+                const auto & dev_layer = lctx.model.dev_layer(il);
                 for (auto & backend : lctx.backends) {
-                    if (ggml_backend_get_device(backend.get()) == dev_layer.dev) {
+                    if (ggml_backend_get_device(backend.get()) == dev_layer) {
                         if (ggml_backend_supports_op(backend.get(), cur)) {
                             ggml_backend_sched_set_tensor_backend(lctx.sched.get(), cur, backend.get());
                         }
@@ -11003,6 +8655,7 @@ static struct ggml_cgraph * llama_build_graph(
                 result = llm.build_phi2();
             } break;
         case LLM_ARCH_PHI3:
+        case LLM_ARCH_PHIMOE:
             {
                 result = llm.build_phi3();
             } break;
@@ -11130,6 +8783,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_rwkv6();
             } break;
+        case LLM_ARCH_RWKV6QWEN2:
+            {
+                result = llm.build_rwkv6qwen2();
+            } break;
         case LLM_ARCH_CHAMELEON:
             {
                 result = llm.build_chameleon();
@@ -11183,73 +8840,33 @@ static enum ggml_status llama_graph_compute(
     return status;
 }
 
-// decode a batch of tokens by evaluating the transformer
-// in case of unsuccessful decoding (error or warning),
-// the kv_cache state will be returned to its original state
-// (for non-recurrent models) or cleaned (for recurrent models)
-//
-//   - lctx:      llama context
-//   - batch:     batch to evaluate
-//
-// return 0 on success
-// return positive int on warning
-// return negative int on error
-//
-static int llama_decode_internal(
-         llama_context & lctx,
-           llama_batch   inp_batch) {
-
-    lctx.is_encoding = false;
-
-    if (inp_batch.n_tokens == 0) {
-        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
-        return -1;
-    }
-
-    // temporary allocate memory for the input batch if needed
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
-
-    const llama_batch & batch = batch_allocr.batch;
-    const uint32_t n_tokens_all = batch.n_tokens;
-
+static int llama_prepare_sbatch(
+        llama_context     & lctx,
+        const llama_batch & batch,
+        uint32_t          & n_outputs) {
     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
 
-    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+    const uint32_t n_tokens_all = batch.n_tokens;
+    const  int64_t n_embd       = hparams.n_embd;
 
+    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
     if (batch.token) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            if (batch.token[i] < 0 || (uint32_t)batch.token[i] >= model.vocab.n_vocab) {
+            if (batch.token[i] < 0 || uint32_t(batch.token[i]) >= model.vocab.n_tokens()) {
                 LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                 return -1;
             }
         }
     }
-
     GGML_ASSERT(n_tokens_all <= cparams.n_batch);
-
     GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
 
-    if (lctx.t_compute_start_us == 0) {
-        lctx.t_compute_start_us = ggml_time_us();
-    }
     lctx.n_queued_tokens += n_tokens_all;
-
-    auto & kv_self = lctx.kv_self;
-    llama_kv_slot_restorer kv_slot_restorer(kv_self);
-
-    const int64_t n_embd  = hparams.n_embd;
-    const int64_t n_vocab = hparams.n_vocab;
-
-    uint32_t n_outputs = 0;
-    uint32_t n_outputs_prev = 0;
-
-    const auto n_ubatch = cparams.n_ubatch;
-
-    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
     lctx.embd_seq.clear();
 
     // count outputs
@@ -11265,7 +8882,7 @@ static int llama_decode_internal(
     }
 
     lctx.sbatch.from_batch(batch, batch.n_embd,
-        /* simple_split */ !kv_self.recurrent,
+        /* simple_split */ !lctx.kv_self.recurrent,
         /* logits_all   */ n_outputs == n_tokens_all);
 
     // reserve output buffer
@@ -11274,75 +8891,152 @@ static int llama_decode_internal(
         return -2;
     };
 
+    return 0;
+}
+
+static int llama_prepare_ubatch(
+        llama_context          & lctx,
+        llama_kv_slot_restorer & kv_slot_restorer,
+        llama_ubatch           & ubatch,
+        const uint32_t           n_outputs,
+        const uint32_t           n_tokens_all) {
+    GGML_ASSERT(lctx.sbatch.n_tokens > 0);
+
+    auto       & kv_self = lctx.kv_self;
+    const auto & cparams = lctx.cparams;
+    const auto & hparams = lctx.model.hparams;
+
+    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+    if (lctx.kv_self.recurrent) {
+        if (embd_pooled) {
+            // Pooled embeddings cannot be split across ubatches (yet)
+            ubatch = lctx.sbatch.split_seq(cparams.n_ubatch);
+        } else {
+            // recurrent model architectures are easier to implement
+            // with equal-length sequences
+            ubatch = lctx.sbatch.split_equal(cparams.n_ubatch);
+        }
+    } else {
+        ubatch = lctx.sbatch.split_simple(cparams.n_ubatch);
+    }
+
+    // count the outputs in this u_batch
+    {
+        int32_t n_outputs_new = 0;
+
+        if (n_outputs == n_tokens_all) {
+            n_outputs_new = ubatch.n_tokens;
+        } else {
+            GGML_ASSERT(ubatch.output);
+            for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
+                n_outputs_new += int32_t(ubatch.output[i] != 0);
+            }
+        }
+
+        // needs to happen before the graph is built
+        lctx.n_outputs = n_outputs_new;
+    }
+
+    // non-causal masks do not use the KV cache
+    if (hparams.causal_attn) {
+        llama_kv_cache_update(&lctx);
+
+        // if we have enough unused cells before the current head ->
+        //   better to start searching from the beginning of the cache, hoping to fill it
+        if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
+            kv_self.head = 0;
+        }
+
+        auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
+        if (!slot) {
+            llama_kv_cache_defrag(kv_self);
+            llama_kv_cache_update(&lctx);
+            slot = llama_kv_cache_find_slot(kv_self, ubatch);
+        }
+        if (!slot) {
+            return 1;
+        }
+        kv_slot_restorer.save(slot);
+
+        if (!kv_self.recurrent) {
+            // a heuristic, to avoid attending the full cache if it is not yet utilized
+            // after enough generations, the benefit from this heuristic disappears
+            // if we start defragmenting the cache, the benefit from this will be more important
+            const uint32_t pad = llama_kv_cache_get_padding(cparams);
+            kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
+            //kv_self.n = llama_kv_cache_cell_max(kv_self);
+        }
+    }
+
+    return 0;
+}
+
+// decode a batch of tokens by evaluating the transformer
+// in case of unsuccessful decoding (error or warning),
+// the kv_cache state will be returned to its original state
+// (for non-recurrent models) or cleaned (for recurrent models)
+//
+//   - lctx:      llama context
+//   - inp_batch: batch to evaluate
+//
+// return 0 on success
+// return positive int on warning
+// return negative int on error
+//
+static int llama_decode_impl(
+         llama_context & lctx,
+           llama_batch   inp_batch) {
+
+    lctx.is_encoding = false;
+
+    if (inp_batch.n_tokens == 0) {
+        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
+        return -1;
+    }
+
+    // temporarily allocate memory for the input batch if needed
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
+    const llama_batch & batch = batch_allocr.batch;
+
+    const auto & model   = lctx.model;
+    const auto & hparams = model.hparams;
+    const auto & cparams = lctx.cparams;
+
+    if (lctx.t_compute_start_us == 0) {
+        lctx.t_compute_start_us = ggml_time_us();
+    }
+    auto & kv_self = lctx.kv_self;
+    llama_kv_slot_restorer kv_slot_restorer(kv_self);
+
+    const int64_t n_embd  = hparams.n_embd;
+    const int64_t n_vocab = hparams.n_vocab;
+
+    uint32_t n_outputs = 0;
+    uint32_t n_outputs_prev = 0;
+
+    {
+        const int ret = llama_prepare_sbatch(lctx, batch, n_outputs);
+        if (ret != 0) {
+            return ret;
+        }
+    }
+
     while (lctx.sbatch.n_tokens > 0) {
         llama_ubatch ubatch;
-        if (kv_self.recurrent) {
-            if (embd_pooled) {
-                // Pooled embeddings cannot be split across ubatches (yet)
-                ubatch = lctx.sbatch.split_seq(n_ubatch);
-            } else {
-                // recurrent model architectures are easier to implement
-                // with equal-length sequences
-                ubatch = lctx.sbatch.split_equal(n_ubatch);
-            }
-        } else {
-            ubatch = lctx.sbatch.split_simple(n_ubatch);
-        }
-        const uint32_t n_tokens = ubatch.n_tokens;
-
-        // count the outputs in this u_batch
         {
-            int32_t n_outputs_new = 0;
-
-            if (n_outputs == n_tokens_all) {
-                n_outputs_new = n_tokens;
-            } else {
-                GGML_ASSERT(ubatch.output);
-                for (uint32_t i = 0; i < n_tokens; i++) {
-                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
-                }
+            const int ret = llama_prepare_ubatch(lctx, kv_slot_restorer, ubatch, n_outputs, batch.n_tokens);
+            if (ret != 0) {
+                return ret;
             }
-
-            // needs to happen before the graph is built
-            lctx.n_outputs = n_outputs_new;
         }
 
-        int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
-        ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
+        const int         n_threads  = ubatch.n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+        ggml_threadpool_t threadpool = ubatch.n_tokens == 1 ? lctx.threadpool   : lctx.threadpool_batch;
 
         GGML_ASSERT(n_threads > 0);
 
-        // non-causal masks do not use the KV cache
-        if (hparams.causal_attn) {
-            llama_kv_cache_update(&lctx);
-
-            // if we have enough unused cells before the current head ->
-            //   better to start searching from the beginning of the cache, hoping to fill it
-            if (kv_self.head > kv_self.used + 2*n_tokens) {
-                kv_self.head = 0;
-            }
-
-            auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
-            if (!slot) {
-                llama_kv_cache_defrag(kv_self);
-                llama_kv_cache_update(&lctx);
-                slot = llama_kv_cache_find_slot(kv_self, ubatch);
-            }
-            if (!slot) {
-                return 1;
-            }
-            kv_slot_restorer.save(slot);
-
-            if (!kv_self.recurrent) {
-                // a heuristic, to avoid attending the full cache if it is not yet utilized
-                // after enough generations, the benefit from this heuristic disappears
-                // if we start defragmenting the cache, the benefit from this will be more important
-                const uint32_t pad = llama_kv_cache_get_padding(cparams);
-                kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
-                //kv_self.n = llama_kv_cache_cell_max(kv_self);
-            }
-        }
-
         //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
 
         ggml_backend_sched_reset(lctx.sched.get());
@@ -11397,7 +9091,7 @@ static int llama_decode_internal(
 
         // update the kv ring buffer
         {
-            kv_self.head += n_tokens;
+            kv_self.head += ubatch.n_tokens;
 
             // Ensure kv cache head points to a valid index.
             if (kv_self.head >= kv_self.size) {
@@ -11510,12 +9204,14 @@ static int llama_decode_internal(
     //llama_synchronize(&lctx);
 
     // decide if we need to defrag the kv cache
-    if (cparams.causal_attn && cparams.defrag_thold >= 0.0f) {
-        const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used)/float(kv_self.n) : 0.0f;
+    if (cparams.causal_attn && cparams.defrag_thold > 0.0f) {
+        // - do not defrag small contexts (i.e. < 2048 tokens)
+        // - count the padding towards the number of used tokens
+        const float fragmentation = kv_self.n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self.used + llama_kv_cache_get_padding(cparams))/float(kv_self.n)) : 0.0f;
 
         // queue defragmentation for next llama_kv_cache_update
         if (fragmentation > cparams.defrag_thold) {
-            //LLAMA_LOG_INFO("fragmentation: %.2f\n", fragmentation);
+            LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
 
             llama_kv_cache_defrag(kv_self);
         }
@@ -11537,7 +9233,7 @@ static int llama_decode_internal(
 // return positive int on warning
 // return negative int on error
 //
-static int llama_encode_internal(
+static int llama_encode_impl(
          llama_context & lctx,
            llama_batch   inp_batch) {
 
@@ -11562,7 +9258,7 @@ static int llama_encode_internal(
 
     if (batch.token) {
         for (uint32_t i = 0; i < n_tokens; ++i) {
-            if (batch.token[i] < 0 || (uint32_t)batch.token[i] >= model.vocab.n_vocab) {
+            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
                 LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                 return -1;
             }
@@ -11719,7 +9415,7 @@ static int llama_encode_internal(
 }
 
 // find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
-static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
+static void llama_kv_cache_defrag_impl(struct llama_context & lctx) {
     auto & kv_self = lctx.kv_self;
 
     const auto & hparams = lctx.model.hparams;
@@ -11739,9 +9435,9 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
     // each move requires 6*n_layer tensors (see build_defrag)
     //   - source view, destination view, copy operation
     //   - x2 for keys and values
-    //const uint32_t max_moves = llama_model_max_nodes(model)/(6*n_layer);
+    //const uint32_t max_moves = model.max_nodes()/(6*n_layer);
     // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
-    const uint32_t max_moves = (llama_model_max_nodes(lctx.model) - 2*n_layer)/(6*n_layer);
+    const uint32_t max_moves = (lctx.model.max_nodes() - 2*n_layer)/(6*n_layer);
 
     // determine which KV cells to move where
     //
@@ -11934,7 +9630,7 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
     //LLAMA_LOG_INFO("(tmp log) KV defrag time: %.3f ms\n", (t_end - t_start)/1000.0);
 }
 
-static void llama_kv_cache_update_internal(struct llama_context & lctx) {
+static void llama_kv_cache_update_impl(struct llama_context & lctx) {
     bool need_reserve = false;
 
     if (lctx.kv_self.has_shift) {
@@ -11970,7 +9666,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
 
     // defragment the KV cache if needed
     if (lctx.kv_self.do_defrag) {
-        llama_kv_cache_defrag_internal(lctx);
+        llama_kv_cache_defrag_impl(lctx);
 
         need_reserve = true;
 
@@ -11983,7 +9679,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
         // build worst-case graph
         uint32_t n_seqs = 1; // TODO: worst-case number of sequences
         uint32_t n_tokens = std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch);
-        llama_token token = llama_token_bos(&lctx.model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
+        llama_token token = lctx.model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
         llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
         ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);
 
@@ -11995,45 +9691,38 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
     }
 }
 
-int32_t llama_lora_adapter_set(
+int32_t llama_set_adapter_lora(
             struct llama_context * ctx,
-            struct llama_lora_adapter * adapter,
+            struct llama_adapter_lora * adapter,
             float scale) {
-    if (ctx->cparams.flash_attn) {
-        LLAMA_LOG_ERROR("%s: flash_attn is not compatible with LoRA\n", __func__);
-        return -1;
-    }
-
-    ctx->lora_adapters[adapter] = scale;
-
+    ctx->lora[adapter] = scale;
     return 0;
 }
 
-int32_t llama_lora_adapter_remove(
+int32_t llama_rm_adapter_lora(
             struct llama_context * ctx,
-            struct llama_lora_adapter * adapter) {
-    auto pos = ctx->lora_adapters.find(adapter);
-    if (pos != ctx->lora_adapters.end()) {
-        ctx->lora_adapters.erase(pos);
+            struct llama_adapter_lora * adapter) {
+    auto pos = ctx->lora.find(adapter);
+    if (pos != ctx->lora.end()) {
+        ctx->lora.erase(pos);
         return 0;
     }
 
     return -1;
 }
 
-void llama_lora_adapter_clear(struct llama_context * ctx) {
-    ctx->lora_adapters.clear();
+void llama_clear_adapter_lora(struct llama_context * ctx) {
+    ctx->lora.clear();
 }
 
-// TODO: tmp
-int32_t llama_control_vector_apply(
-        struct llama_context * lctx,
+int32_t llama_apply_adapter_cvec(
+        struct llama_context * ctx,
                  const float * data,
                       size_t   len,
                      int32_t   n_embd,
                      int32_t   il_start,
                      int32_t   il_end) {
-    return llama_control_vector_apply(lctx->cvec, lctx->model, data, len, n_embd, il_start, il_end);
+    return ctx->cvec.apply(ctx->model, data, len, n_embd, il_start, il_end);
 }
 
 //
@@ -12134,13 +9823,12 @@ int64_t llama_time_us(void) {
     return ggml_time_us();
 }
 
-struct llama_model * llama_load_model_from_file(
-        const char * path_model,
+static struct llama_model * llama_model_load_from_file_impl(
+        const std::string & path_model,
+        std::vector & splits,
         struct llama_model_params params) {
     ggml_time_init();
 
-    llama_model * model = new llama_model;
-
     unsigned cur_percentage = 0;
     if (params.progress_callback == NULL) {
         params.progress_callback_user_data = &cur_percentage;
@@ -12158,46 +9846,7 @@ struct llama_model * llama_load_model_from_file(
         };
     }
 
-    if (params.rpc_servers != nullptr && params.rpc_servers[0] != '\0') {
-        // split the servers set them into model->rpc_servers
-        std::string servers(params.rpc_servers);
-        size_t pos = 0;
-        while ((pos = servers.find(',')) != std::string::npos) {
-            std::string server = servers.substr(0, pos);
-            model->rpc_servers.push_back(server);
-            servers.erase(0, pos + 1);
-        }
-        model->rpc_servers.push_back(servers);
-    }
-
-    // add RPC devices
-    if (!model->rpc_servers.empty()) {
-        ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
-        if (!rpc_reg) {
-            LLAMA_LOG_ERROR("%s: failed to find RPC backend\n", __func__);
-            llama_free_model(model);
-            return nullptr;
-        }
-
-        typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint);
-        ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
-        if (!ggml_backend_rpc_add_device_fn) {
-            LLAMA_LOG_ERROR("%s: failed to find RPC device add function\n", __func__);
-            llama_free_model(model);
-            return nullptr;
-        }
-
-        for (const std::string & server : model->rpc_servers) {
-            ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str());
-            if (dev) {
-                model->devices.push_back(dev);
-            } else {
-                LLAMA_LOG_ERROR("%s: failed to add RPC device for server '%s'\n", __func__, server.c_str());
-                llama_free_model(model);
-                return nullptr;
-            }
-        }
-    }
+    llama_model * model = new llama_model(params);
 
     // create list of devices to use with this model
     if (params.devices) {
@@ -12205,6 +9854,7 @@ struct llama_model * llama_load_model_from_file(
             model->devices.push_back(*dev);
         }
     } else {
+        std::vector rpc_servers;
         // use all available devices
         for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
             ggml_backend_dev_t dev = ggml_backend_dev_get(i);
@@ -12215,17 +9865,26 @@ struct llama_model * llama_load_model_from_file(
                     break;
 
                 case GGML_BACKEND_DEVICE_TYPE_GPU:
-                    model->devices.push_back(dev);
+                    ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev);
+                    if (ggml_backend_reg_name(reg) == std::string("RPC")) {
+                        rpc_servers.push_back(dev);
+                    } else {
+                        model->devices.push_back(dev);
+                    }
                     break;
             }
         }
+        // add RPC servers at the front of the list
+        if (!rpc_servers.empty()) {
+            model->devices.insert(model->devices.begin(), rpc_servers.begin(), rpc_servers.end());
+        }
     }
 
     // if using single GPU mode, remove all except the main GPU
     if (params.split_mode == LLAMA_SPLIT_MODE_NONE) {
         if (params.main_gpu < 0 || params.main_gpu >= (int)model->devices.size()) {
             LLAMA_LOG_ERROR("%s: invalid value for main_gpu: %d (available devices: %d)\n", __func__, params.main_gpu, (int)model->devices.size());
-            llama_free_model(model);
+            llama_model_free(model);
             return nullptr;
         }
         ggml_backend_dev_t main_gpu = model->devices[params.main_gpu];
@@ -12239,7 +9898,7 @@ struct llama_model * llama_load_model_from_file(
         LLAMA_LOG_INFO("%s: using device %s (%s) - %zu MiB free\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), free/1024/1024);
     }
 
-    int status = llama_model_load(path_model, *model, params);
+    const int status = llama_model_load(path_model, splits, *model, params);
     GGML_ASSERT(status <= 0);
     if (status < 0) {
         if (status == -1) {
@@ -12248,14 +9907,43 @@ struct llama_model * llama_load_model_from_file(
             LLAMA_LOG_INFO("%s: cancelled model load\n", __func__);
         }
 
-        llama_free_model(model);
+        llama_model_free(model);
         return nullptr;
     }
 
     return model;
 }
 
-struct llama_context * llama_new_context_with_model(
+// deprecated
+struct llama_model * llama_load_model_from_file(
+        const char * path_model,
+        struct llama_model_params params) {
+    return llama_model_load_from_file(path_model, params);
+}
+
+struct llama_model * llama_model_load_from_file(
+        const char * path_model,
+        struct llama_model_params params) {
+    std::vector splits = {};
+    return llama_model_load_from_file_impl(path_model, splits, params);
+}
+
+struct llama_model * llama_model_load_from_splits(
+        const char ** paths,
+        size_t n_paths,
+        struct llama_model_params params) {
+    std::vector splits;
+    if (n_paths == 0) {
+        LLAMA_LOG_ERROR("%s: list of splits is empty\n", __func__);
+        return nullptr;
+    }
+    for (size_t i = 0; i < n_paths; ++i) {
+        splits.push_back(paths[i]);
+    }
+    return llama_model_load_from_file_impl(splits.front(), splits, params);
+}
+
+struct llama_context * llama_init_from_model(
                  struct llama_model * model,
         struct llama_context_params   params) {
 
@@ -12513,7 +10201,7 @@ struct llama_context * llama_new_context_with_model(
                 backend_ptrs.push_back(backend.get());
             }
 
-            const size_t max_nodes = llama_model_max_nodes(*model);
+            const size_t max_nodes = model->max_nodes();
 
             // buffer used to store the computation graph and the tensor meta data
             ctx->buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
@@ -12521,9 +10209,9 @@ struct llama_context * llama_new_context_with_model(
             // TODO: move these checks to ggml_backend_sched
             // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
             bool pipeline_parallel =
-                llama_get_device_count(*model) > 1 &&
-                model->n_gpu_layers > (int)model->hparams.n_layer &&
-                model->split_mode == LLAMA_SPLIT_MODE_LAYER &&
+                model->n_devices() > 1 &&
+                model->params.n_gpu_layers > (int)model->hparams.n_layer &&
+                model->params.split_mode == LLAMA_SPLIT_MODE_LAYER &&
                 params.offload_kqv;
 
             // pipeline parallelism requires support for async compute and events in all devices
@@ -12554,7 +10242,7 @@ struct llama_context * llama_new_context_with_model(
             // initialize scheduler with the worst-case graph
             uint32_t n_seqs = 1; // TODO: worst-case number of sequences
             uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-            llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
+            llama_token token = ctx->model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
 
             llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
             ggml_cgraph * gf_pp = llama_build_graph(*ctx, ubatch_pp, true);
@@ -12606,6 +10294,12 @@ struct llama_context * llama_new_context_with_model(
     return ctx;
 }
 
+struct llama_context * llama_new_context_with_model(
+                 struct llama_model * model,
+        struct llama_context_params   params) {
+    return llama_init_from_model(model, params);
+}
+
 //
 // kv cache
 //
@@ -12672,7 +10366,7 @@ void llama_kv_cache_defrag(struct llama_context * ctx) {
 }
 
 void llama_kv_cache_update(struct llama_context * ctx) {
-    llama_kv_cache_update_internal(*ctx);
+    llama_kv_cache_update_impl(*ctx);
 }
 
 bool llama_kv_cache_can_shift(struct llama_context * ctx) {
@@ -12684,7 +10378,7 @@ bool llama_kv_cache_can_shift(struct llama_context * ctx) {
 int32_t llama_encode(
         struct llama_context * ctx,
           struct llama_batch   batch) {
-    const int ret = llama_encode_internal(*ctx, batch);
+    const int ret = llama_encode_impl(*ctx, batch);
     if (ret != 0) {
         LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
     }
@@ -12695,7 +10389,7 @@ int32_t llama_encode(
 int32_t llama_decode(
         struct llama_context * ctx,
           struct llama_batch   batch) {
-    const int ret = llama_decode_internal(*ctx, batch);
+    const int ret = llama_decode_impl(*ctx, batch);
     if (ret != 0) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
@@ -12703,166 +10397,18 @@ int32_t llama_decode(
     return ret;
 }
 
-//
-// vocab
-//
-
-// TODO: tmp bridges below until `struct llama_vocab` is exposed through the public API
-
-const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
-    return llama_token_get_text_impl(model->vocab, token);
-}
-
-float llama_token_get_score(const struct llama_model * model, llama_token token) {
-    return llama_token_get_score_impl(model->vocab, token);
-}
-
-enum llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token) {
-    return llama_token_get_attr_impl(model->vocab, token);
-}
-
-bool llama_token_is_eog(const struct llama_model * model, llama_token token) {
-    return llama_token_is_eog_impl(model->vocab, token);
-}
-
-bool llama_token_is_control(const struct llama_model * model, llama_token token) {
-    return llama_token_is_control_impl(model->vocab, token);
-}
-
-llama_token llama_token_bos(const struct llama_model * model) {
-    return llama_token_bos_impl(model->vocab);
-}
-
-llama_token llama_token_eos(const struct llama_model * model) {
-    return llama_token_eos_impl(model->vocab);
-}
-
-llama_token llama_token_eot(const struct llama_model * model) {
-    return llama_token_eot_impl(model->vocab);
-}
-
-llama_token llama_token_cls(const struct llama_model * model) {
-    return llama_token_cls_impl(model->vocab);
-}
-
-llama_token llama_token_sep(const struct llama_model * model) {
-    return llama_token_sep_impl(model->vocab);
-}
-
-llama_token llama_token_nl (const struct llama_model * model) {
-    return llama_token_nl_impl(model->vocab);
-}
-
-llama_token llama_token_pad(const struct llama_model * model) {
-    return llama_token_pad_impl(model->vocab);
-}
-
-bool llama_add_bos_token(const struct llama_model * model) {
-    return llama_add_bos_token_impl(model->vocab);
-}
-
-bool llama_add_eos_token(const struct llama_model * model) {
-    return llama_add_eos_token_impl(model->vocab);
-}
-
-llama_token llama_token_prefix(const struct llama_model * model) {
-    return llama_token_prefix_impl(model->vocab);
-}
-
-llama_token llama_token_middle(const struct llama_model * model) {
-    return llama_token_middle_impl(model->vocab);
-}
-
-llama_token llama_token_suffix(const struct llama_model * model) {
-    return llama_token_suffix_impl(model->vocab);
-}
-
-llama_token llama_token_fim_pre(const struct llama_model * model) {
-    return llama_token_fim_pre_impl(model->vocab);
-}
-
-llama_token llama_token_fim_suf(const struct llama_model * model) {
-    return llama_token_fim_suf_impl(model->vocab);
-}
-
-llama_token llama_token_fim_mid(const struct llama_model * model) {
-    return llama_token_fim_mid_impl(model->vocab);
-}
-
-llama_token llama_token_fim_pad(const struct llama_model * model) {
-    return llama_token_fim_pad_impl(model->vocab);
-}
-
-llama_token llama_token_fim_rep(const struct llama_model * model) {
-    return llama_token_fim_rep_impl(model->vocab);
-}
-
-llama_token llama_token_fim_sep(const struct llama_model * model) {
-    return llama_token_fim_sep_impl(model->vocab);
-}
-
-//
-// tokenization
-//
-
-int32_t llama_tokenize(
-    const struct llama_model * model,
-                  const char * text,
-                     int32_t   text_len,
-                 llama_token * tokens,
-                     int32_t   n_tokens_max,
-                        bool   add_special,
-                        bool   parse_special) {
-    return llama_tokenize_impl(model->vocab, text, text_len, tokens, n_tokens_max, add_special, parse_special);
-}
-
-int32_t llama_token_to_piece(
-    const struct llama_model * model,
-                 llama_token   token,
-                        char * buf,
-                     int32_t   length,
-                     int32_t   lstrip,
-                        bool   special) {
-    return llama_token_to_piece_impl(model->vocab, token, buf, length, lstrip, special);
-}
-
-int32_t llama_detokenize(
-    const struct llama_model * model,
-           const llama_token * tokens,
-                     int32_t   n_tokens,
-                        char * text,
-                     int32_t   text_len_max,
-                        bool   remove_special,
-                        bool   unparse_special) {
-    return llama_detokenize_impl(model->vocab, tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
-}
-
 //
 // chat templates
 //
 
 int32_t llama_chat_apply_template(
-                const struct llama_model * model,
                               const char * tmpl,
          const struct llama_chat_message * chat,
                                   size_t   n_msg,
                                     bool   add_ass,
                                     char * buf,
                                  int32_t   length) {
-    std::string curr_tmpl(tmpl == nullptr ? "" : tmpl);
-    if (tmpl == nullptr) {
-        GGML_ASSERT(model != nullptr);
-
-        // load template from model, if available
-        const auto & it = model->gguf_kv.find("tokenizer.chat_template");
-        if (it != model->gguf_kv.end() && it->second.size() > 0) {
-            curr_tmpl = it->second;
-        }
-        else {
-            // worst case: there is no information about template, we will use chatml by default
-            curr_tmpl = "chatml";  // see llm_chat_apply_template
-        }
-    }
+    const std::string curr_tmpl(tmpl == nullptr ? "chatml" : tmpl);
 
     // format the chat to string
     std::vector chat_vec;
@@ -12886,23 +10432,6 @@ int32_t llama_chat_apply_template(
     return res;
 }
 
-//
-// sampling
-//
-
-// TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp
-struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * model, const char * grammar_str, const char * grammar_root) {
-    return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root);
-}
-
-struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model) {
-    return llama_sampler_init_infill_impl(model->vocab);
-}
-
-struct llama_sampler * llama_sampler_init_dry(const struct llama_model * model, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
-    return llama_sampler_init_dry_impl(model->vocab, llama_n_ctx_train(model), dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, seq_breakers, num_breakers);
-}
-
 //
 // model split
 //
@@ -12915,16 +10444,16 @@ int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix,
     return 0;
 }
 
-int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int split_no, int split_count) {
+int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count) {
     std::string str_split_path(split_path);
     char postfix[32];
     snprintf(postfix, 32, "-%05d-of-%05d.gguf", split_no + 1, split_count);
     std::string str_postfix(postfix);
 
-    // check if dest ends with postfix
+    // check if split_prefix ends with postfix
     int size_prefix = str_split_path.size() - str_postfix.size();
     if (size_prefix > 0 && str_split_path.find(str_postfix, size_prefix) != std::string::npos) {
-        snprintf(dest, std::min((size_t) size_prefix + 1, maxlen), "%s", split_path);
+        snprintf(split_prefix, std::min((size_t) size_prefix + 1, maxlen), "%s", split_path);
         return size_prefix;
     }
 
@@ -12933,6 +10462,8 @@ int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int
 
 const char * llama_print_system_info(void) {
     static std::string s;
+    s.clear(); // Clear the string, since it's static, otherwise it will accumulate data from previous calls.
+
 
     for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
         auto * reg = ggml_backend_reg_get(i);
diff --git a/llama/llama.cpp/src/unicode.cpp b/llama/llama.cpp/src/unicode.cpp
index 6155da80..9dd53b9a 100644
--- a/llama/llama.cpp/src/unicode.cpp
+++ b/llama/llama.cpp/src/unicode.cpp
@@ -12,18 +12,17 @@
 
 #include 
 #include 
+#include 
 #include 
 #include 
+#include 
 #include 
 #include 
 #include 
 #include 
 #include 
-#include 
 #include 
 #include 
-#include 
-#include 
 
 size_t unicode_len_utf8(char src) {
     const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
@@ -641,7 +640,14 @@ std::vector unicode_cpts_from_utf8(const std::string & utf8) {
     result.reserve(utf8.size());
     size_t offset = 0;
     while (offset < utf8.size()) {
-        result.push_back(unicode_cpt_from_utf8(utf8, offset));
+        try {
+            result.push_back(unicode_cpt_from_utf8(utf8, offset));
+        }
+        catch (const std::invalid_argument & /*ex*/) {
+            // Silently ignore invalid UTF-8 input to avoid leaking the exception beyond llama_tokenize
+            ++offset;
+            result.emplace_back(0xFFFD); // replacement character
+        }
     }
     return result;
 }
@@ -724,7 +730,7 @@ std::vector unicode_regex_split(const std::string & text, const std
     const auto cpts = unicode_cpts_from_utf8(text);
 
     // generate a "collapsed" representation of the text, where all codepoints are replaced by a single byte
-    // ref: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2081479935
+    // ref: https://github.com/ggml-org/llama.cpp/pull/6920#issuecomment-2081479935
     std::string text_collapsed;
     if (need_collapse) {
         // collapse all unicode categories
diff --git a/llama/llama.go b/llama/llama.go
index a20f2357..6eed3d47 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -14,6 +14,7 @@ package llama
 #include "llama.h"
 #include "clip.h"
 #include "llava.h"
+#include "gguf.h"
 
 #include "mllama.h"
 #include "sampling_ext.h"
@@ -293,29 +294,29 @@ func NewContextWithModel(model *Model, params ContextParams) (*Context, error) {
 }
 
 func (m *Model) NumVocab() int {
-	return int(C.llama_n_vocab(m.c))
+	return int(C.llama_n_vocab(m.Vocab()))
 }
 
 func (m *Model) TokenIsEog(token int) bool {
-	return bool(C.llama_token_is_eog(m.c, C.llama_token(token)))
+	return bool(C.llama_token_is_eog(m.Vocab(), C.llama_token(token)))
 }
 
 func (m *Model) AddBOSToken() bool {
-	return bool(C.llama_add_bos_token(m.c))
+	return bool(C.llama_add_bos_token(m.Vocab()))
 }
 
 func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float32, threads int) error {
 	cLoraPath := C.CString(loraPath)
 	defer C.free(unsafe.Pointer(cLoraPath))
 
-	loraAdapter := C.llama_lora_adapter_init(m.c, cLoraPath)
+	loraAdapter := C.llama_adapter_lora_init(m.c, cLoraPath)
 	if loraAdapter == nil {
 		return errors.New("unable to load lora")
 	}
 
 	err := -1
 	if loraAdapter != nil {
-		err = int(C.llama_lora_adapter_set(context.c, loraAdapter, C.float(scale)))
+		err = int(C.llama_set_adapter_lora(context.c, loraAdapter, C.float(scale)))
 	}
 	if err != 0 {
 		return errors.New("error applying lora from file")
@@ -324,6 +325,10 @@ func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float
 	return nil
 }
 
+func (m *Model) Vocab() *C.struct_llama_vocab {
+	return C.llama_model_get_vocab(m.c)
+}
+
 type Batch struct {
 	c         C.struct_llama_batch
 	batchSize int
@@ -414,7 +419,7 @@ func (m *Model) TokenToPiece(token int) string {
 	tokenLen := 12
 	buf := make([]byte, tokenLen)
 	tokenLen = int(C.llama_token_to_piece(
-		m.c,
+		m.Vocab(),
 		C.int32_t(token),
 		(*C.char)(unsafe.Pointer(&buf[0])),
 		C.int32_t(tokenLen),
@@ -426,7 +431,7 @@ func (m *Model) TokenToPiece(token int) string {
 
 		buf = make([]byte, tokenLen)
 		C.llama_token_to_piece(
-			m.c,
+			m.Vocab(),
 			C.int32_t(token),
 			(*C.char)(unsafe.Pointer(&buf[0])),
 			C.int32_t(tokenLen),
@@ -444,7 +449,7 @@ func (m *Model) Tokenize(text string, addSpecial bool, parseSpecial bool) ([]int
 	defer C.free(unsafe.Pointer(cText))
 
 	result := C.llama_tokenize(
-		m.c,
+		m.Vocab(),
 		cText,
 		C.int32_t(len(text)),
 		&cTokens[0],
@@ -458,7 +463,7 @@ func (m *Model) Tokenize(text string, addSpecial bool, parseSpecial bool) ([]int
 		maxTokens = int(-result)
 		cTokens = make([]C.llama_token, maxTokens)
 		result = C.llama_tokenize(
-			m.c,
+			m.Vocab(),
 			cText,
 			C.int32_t(len(text)),
 			&cTokens[0],
diff --git a/llama/mllama.cpp b/llama/mllama.cpp
index 4e84c60a..1ba8f5be 100644
--- a/llama/mllama.cpp
+++ b/llama/mllama.cpp
@@ -5,6 +5,7 @@
 #include "ggml-backend.h"
 #include "ggml-cpu.h"
 #include "ggml.h"
+#include "gguf.h"
 
 #ifdef GGML_USE_CUDA
 #include "ggml-cuda.h"
diff --git a/llama/patches/0001-cuda.patch b/llama/patches/0001-cuda.patch
index 0bf338f2..a766c30c 100644
--- a/llama/patches/0001-cuda.patch
+++ b/llama/patches/0001-cuda.patch
@@ -10,7 +10,7 @@ Subject: [PATCH] cuda
  3 files changed, 2 insertions(+), 1 deletion(-)
 
 diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
-index e2d6c405..a12172dc 100644
+index dba7be33..1ca40b2c 100644
 --- a/ggml/src/ggml-backend.cpp
 +++ b/ggml/src/ggml-backend.cpp
 @@ -106,7 +106,6 @@ void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) {
@@ -22,10 +22,10 @@ index e2d6c405..a12172dc 100644
  
  size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) {
 diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
-index 0b06be72..be29e979 100644
+index ebb2ccae..b094929b 100644
 --- a/ggml/src/ggml-cuda/ggml-cuda.cu
 +++ b/ggml/src/ggml-cuda/ggml-cuda.cu
-@@ -424,6 +424,7 @@ struct ggml_backend_cuda_buffer_context {
+@@ -529,6 +529,7 @@ struct ggml_backend_cuda_buffer_context {
  static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) {
      ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
      delete ctx;
@@ -34,10 +34,10 @@ index 0b06be72..be29e979 100644
  
  static bool ggml_backend_buffer_is_cuda(ggml_backend_buffer_t buffer) {
 diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
-index a85502ee..cd8ef741 100644
+index c550142a..fd9a4e77 100644
 --- a/ggml/src/ggml-metal/ggml-metal.m
 +++ b/ggml/src/ggml-metal/ggml-metal.m
-@@ -4187,6 +4187,7 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
+@@ -4350,6 +4350,7 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
      }
  
      free(ctx);
diff --git a/llama/patches/0002-pretokenizer.patch b/llama/patches/0002-pretokenizer.patch
index 189a996f..93a5ce59 100644
--- a/llama/patches/0002-pretokenizer.patch
+++ b/llama/patches/0002-pretokenizer.patch
@@ -4,17 +4,17 @@ Date: Mon, 16 Sep 2024 15:53:13 -0700
 Subject: [PATCH] pretokenizer
 
 ---
- src/llama-model.cpp | 14 +++-----------
+ src/llama-vocab.cpp | 14 +++-----------
  1 file changed, 3 insertions(+), 11 deletions(-)
 
-diff --git a/src/llama-model.cpp b/src/llama-model.cpp
-index 405e0528..00b80c52 100644
---- a/src/llama-model.cpp
-+++ b/src/llama-model.cpp
-@@ -1249,16 +1249,7 @@ void llm_load_vocab(llama_model_loader & ml, llama_model & model) {
-         if (vocab.type == LLAMA_VOCAB_TYPE_BPE) {
-             vocab.tokenizer_add_space_prefix = false;
-             vocab.tokenizer_clean_spaces = true;
+diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
+index ad9ffe66..a4eee9b8 100644
+--- a/src/llama-vocab.cpp
++++ b/src/llama-vocab.cpp
+@@ -1468,16 +1468,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
+         if (type == LLAMA_VOCAB_TYPE_BPE) {
+             add_space_prefix = false;
+             clean_spaces = true;
 -            if (tokenizer_pre.empty()) {
 -                LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__);
 -                LLAMA_LOG_WARN("%s:                                             \n", __func__);
@@ -23,19 +23,19 @@ index 405e0528..00b80c52 100644
 -                LLAMA_LOG_WARN("%s: CONSIDER REGENERATING THE MODEL             \n", __func__);
 -                LLAMA_LOG_WARN("%s: ************************************        \n", __func__);
 -                LLAMA_LOG_WARN("%s:                                             \n", __func__);
--                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+-                pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
 -            } else if (tokenizer_pre == "default") {
 +            if (tokenizer_pre == "default") {
-                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+                 pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
              } else if (
                      tokenizer_pre == "llama3"   ||
-@@ -1373,7 +1364,8 @@ void llm_load_vocab(llama_model_loader & ml, llama_model & model) {
+@@ -1593,7 +1584,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                  tokenizer_pre == "megrez") {
-                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_QWEN2;
+                 pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2;
              } else {
 -                throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
 +                LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
-+                vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
++                pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
              }
-         } else if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
-             vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+         } else if (type == LLAMA_VOCAB_TYPE_SPM) {
+             pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
diff --git a/llama/patches/0003-embeddings.patch b/llama/patches/0003-embeddings.patch
index c04ee563..176cb41d 100644
--- a/llama/patches/0003-embeddings.patch
+++ b/llama/patches/0003-embeddings.patch
@@ -9,10 +9,10 @@ Subject: [PATCH] embeddings
  2 files changed, 5 insertions(+), 3 deletions(-)
 
 diff --git a/src/llama-context.cpp b/src/llama-context.cpp
-index 38a55fb2..b9c4a5bf 100644
+index 671d2a81..47e79ed4 100644
 --- a/src/llama-context.cpp
 +++ b/src/llama-context.cpp
-@@ -475,7 +475,7 @@ size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
+@@ -479,7 +479,7 @@ size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
      const auto n_embd  = hparams.n_embd;
  
      // TODO: use a per-batch flag for logits presence instead
@@ -22,10 +22,10 @@ index 38a55fb2..b9c4a5bf 100644
  
      const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
 diff --git a/src/llama.cpp b/src/llama.cpp
-index ea78ea48..4eb3f6b9 100644
+index 607f2786..ac85bfed 100644
 --- a/src/llama.cpp
 +++ b/src/llama.cpp
-@@ -10876,7 +10876,6 @@ static int llama_decode_internal(
+@@ -8652,7 +8652,6 @@ static int llama_decode_impl(
              res  = nullptr;
              embd = nullptr;
          } else if (cparams.embeddings) {
@@ -33,7 +33,7 @@ index ea78ea48..4eb3f6b9 100644
              embd = nullptr;
              for (int i = ggml_graph_n_nodes(gf) - 1; i >= 0; --i) {
                  if (strcmp(ggml_graph_node(gf, i)->name, "result_embd_pooled") == 0) {
-@@ -10884,12 +10883,15 @@ static int llama_decode_internal(
+@@ -8660,12 +8659,15 @@ static int llama_decode_impl(
                      break;
                  }
              }
diff --git a/llama/patches/0004-clip-unicode.patch b/llama/patches/0004-clip-unicode.patch
index 9c90cfd0..50a12aad 100644
--- a/llama/patches/0004-clip-unicode.patch
+++ b/llama/patches/0004-clip-unicode.patch
@@ -8,10 +8,10 @@ Subject: [PATCH] clip-unicode
  1 file changed, 39 insertions(+), 1 deletion(-)
 
 diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
-index 3cd0d2fa..b3c1829f 100644
+index 76d4a785..205af1eb 100644
 --- a/examples/llava/clip.cpp
 +++ b/examples/llava/clip.cpp
-@@ -56,6 +56,19 @@
+@@ -58,6 +58,19 @@
  #   define LOG_DBG(...) do { fprintf(stdout, __VA_ARGS__); } while (0)
  #endif // defined(LLAVA_LOG_OFF)
  
@@ -31,7 +31,7 @@ index 3cd0d2fa..b3c1829f 100644
  //#define CLIP_DEBUG_FUNCTIONS
  
  // RGB uint8 image
-@@ -1322,8 +1335,29 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
+@@ -1402,8 +1415,29 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
              gguf_free(ctx);
              return nullptr;
          }
@@ -62,7 +62,7 @@ index 3cd0d2fa..b3c1829f 100644
          if (!fin) {
              LOG_ERR("cannot open model file for loading tensors\n");
              clip_free(new_clip);
-@@ -1363,7 +1397,11 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
+@@ -1443,7 +1477,11 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
                  ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
              }
          }
diff --git a/llama/patches/0005-solar-pro.patch b/llama/patches/0005-solar-pro.patch
index 33be79c2..e201682b 100644
--- a/llama/patches/0005-solar-pro.patch
+++ b/llama/patches/0005-solar-pro.patch
@@ -11,21 +11,21 @@ tensor to store the scalar. the scalar is implemented a 1-dimensional
 tensor with 2 elements dervied from the model's bskcn_tv configuration.
 in general, the values are (bskcn_tv, 1 - bskcn_tv)
 ---
- src/llama-arch.cpp         |  53 +++++++----
+ src/llama-arch.cpp         |  21 +++++
  src/llama-arch.h           |   3 +
  src/llama-hparams.cpp      |   8 ++
- src/llama-hparams.h        |   5 +
+ src/llama-hparams.h        |   5 ++
  src/llama-model-loader.cpp |   1 +
- src/llama-model.cpp        |  16 ++++
+ src/llama-model.cpp        |  44 +++++++++++
  src/llama-model.h          |   3 +
- src/llama.cpp              | 185 +++++++++++++++++++++++++++++++++++++
- 8 files changed, 258 insertions(+), 16 deletions(-)
+ src/llama.cpp              | 152 ++++++++++++++++++++++++++++++++++++-
+ 8 files changed, 236 insertions(+), 1 deletion(-)
 
 diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
-index 007d79f8..5b376c5e 100644
+index 97a1e7e5..a1e0ebcc 100644
 --- a/src/llama-arch.cpp
 +++ b/src/llama-arch.cpp
-@@ -59,6 +59,7 @@ static const std::map LLM_ARCH_NAMES = {
+@@ -61,6 +61,7 @@ static const std::map LLM_ARCH_NAMES = {
      { LLM_ARCH_GRANITE,          "granite"          },
      { LLM_ARCH_GRANITE_MOE,      "granitemoe"       },
      { LLM_ARCH_CHAMELEON,        "chameleon"        },
@@ -33,48 +33,16 @@ index 007d79f8..5b376c5e 100644
      { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
      { LLM_ARCH_UNKNOWN,          "(unknown)"        },
  };
-@@ -106,22 +107,23 @@ static const std::map LLM_KV_NAMES = {
-     { LLM_KV_RESIDUAL_SCALE,                    "%s.residual_scale"                    },
-     { LLM_KV_EMBEDDING_SCALE,                   "%s.embedding_scale"                   },
- 
--    { LLM_KV_ATTENTION_HEAD_COUNT,             "%s.attention.head_count"             },
--    { LLM_KV_ATTENTION_HEAD_COUNT_KV,          "%s.attention.head_count_kv"          },
--    { LLM_KV_ATTENTION_MAX_ALIBI_BIAS,         "%s.attention.max_alibi_bias"         },
--    { LLM_KV_ATTENTION_CLAMP_KQV,              "%s.attention.clamp_kqv"              },
--    { LLM_KV_ATTENTION_KEY_LENGTH,             "%s.attention.key_length"             },
--    { LLM_KV_ATTENTION_VALUE_LENGTH,           "%s.attention.value_length"           },
--    { LLM_KV_ATTENTION_LAYERNORM_EPS,          "%s.attention.layer_norm_epsilon"     },
--    { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,      "%s.attention.layer_norm_rms_epsilon" },
--    { LLM_KV_ATTENTION_GROUPNORM_EPS,          "%s.attention.group_norm_epsilon"     },
--    { LLM_KV_ATTENTION_GROUPNORM_GROUPS,       "%s.attention.group_norm_groups"      },
--    { LLM_KV_ATTENTION_CAUSAL,                 "%s.attention.causal"                 },
--    { LLM_KV_ATTENTION_Q_LORA_RANK,            "%s.attention.q_lora_rank"            },
--    { LLM_KV_ATTENTION_KV_LORA_RANK,           "%s.attention.kv_lora_rank"           },
--    { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
--    { LLM_KV_ATTENTION_SLIDING_WINDOW,         "%s.attention.sliding_window"         },
--    { LLM_KV_ATTENTION_SCALE,                  "%s.attention.scale"                  },
-+    { LLM_KV_ATTENTION_HEAD_COUNT,               "%s.attention.head_count"               },
-+    { LLM_KV_ATTENTION_HEAD_COUNT_KV,            "%s.attention.head_count_kv"            },
-+    { LLM_KV_ATTENTION_MAX_ALIBI_BIAS,           "%s.attention.max_alibi_bias"           },
-+    { LLM_KV_ATTENTION_CLAMP_KQV,                "%s.attention.clamp_kqv"                },
-+    { LLM_KV_ATTENTION_KEY_LENGTH,               "%s.attention.key_length"               },
-+    { LLM_KV_ATTENTION_VALUE_LENGTH,             "%s.attention.value_length"             },
-+    { LLM_KV_ATTENTION_LAYERNORM_EPS,            "%s.attention.layer_norm_epsilon"       },
-+    { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,        "%s.attention.layer_norm_rms_epsilon"   },
-+    { LLM_KV_ATTENTION_GROUPNORM_EPS,            "%s.attention.group_norm_epsilon"       },
-+    { LLM_KV_ATTENTION_GROUPNORM_GROUPS,         "%s.attention.group_norm_groups"        },
-+    { LLM_KV_ATTENTION_CAUSAL,                   "%s.attention.causal"                   },
-+    { LLM_KV_ATTENTION_Q_LORA_RANK,              "%s.attention.q_lora_rank"              },
-+    { LLM_KV_ATTENTION_KV_LORA_RANK,             "%s.attention.kv_lora_rank"             },
-+    { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,   "%s.attention.relative_buckets_count"   },
-+    { LLM_KV_ATTENTION_SLIDING_WINDOW,           "%s.attention.sliding_window"           },
-+    { LLM_KV_ATTENTION_SCALE,                    "%s.attention.scale"                    },
-+    { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,    "%s.attention.block_skip_connection"    },
+@@ -125,6 +126,7 @@ static const std::map LLM_KV_NAMES = {
+     { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
+     { LLM_KV_ATTENTION_SLIDING_WINDOW,         "%s.attention.sliding_window"         },
+     { LLM_KV_ATTENTION_SCALE,                  "%s.attention.scale"                  },
++    { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,  "%s.attention.block_skip_connection"  },
  
      { LLM_KV_ROPE_DIMENSION_COUNT,      "%s.rope.dimension_count"                 },
      { LLM_KV_ROPE_DIMENSION_SECTIONS,   "%s.rope.dimension_sections"              },
-@@ -1240,6 +1242,24 @@ static const std::map> LLM_TENSOR_N
-             { LLM_TENSOR_POS_NET_ATTN_OUT,  "posnet.%d.attn_output" },
+@@ -1271,6 +1273,24 @@ static const std::map> LLM_TENSOR_N
+             { LLM_TENSOR_ATTN_K_NORM,     "blk.%d.attn_k_norm" },
          },
      },
 +    {
@@ -96,9 +64,9 @@ index 007d79f8..5b376c5e 100644
 +        },
 +    },
      {
-         LLM_ARCH_UNKNOWN,
+         LLM_ARCH_WAVTOKENIZER_DEC,
          {
-@@ -1372,6 +1392,7 @@ static const std::map LLM_TENSOR_INFOS = {
+@@ -1429,6 +1449,7 @@ static const std::map LLM_TENSOR_INFOS = {
      {LLM_TENSOR_FFN_EXP_PROBS_B,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
      // this tensor is loaded for T5, but never used
      {LLM_TENSOR_DEC_CROSS_ATTN_REL_B,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
@@ -107,10 +75,10 @@ index 007d79f8..5b376c5e 100644
      {LLM_TENSOR_POS_NET_NORM,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
      {LLM_TENSOR_POS_NET_NORM1,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
 diff --git a/src/llama-arch.h b/src/llama-arch.h
-index 45e458bb..eac7055b 100644
+index 122fdceb..77919578 100644
 --- a/src/llama-arch.h
 +++ b/src/llama-arch.h
-@@ -63,6 +63,7 @@ enum llm_arch {
+@@ -65,6 +65,7 @@ enum llm_arch {
      LLM_ARCH_GRANITE,
      LLM_ARCH_GRANITE_MOE,
      LLM_ARCH_CHAMELEON,
@@ -118,7 +86,7 @@ index 45e458bb..eac7055b 100644
      LLM_ARCH_WAVTOKENIZER_DEC,
      LLM_ARCH_UNKNOWN,
  };
-@@ -126,6 +127,7 @@ enum llm_kv {
+@@ -129,6 +130,7 @@ enum llm_kv {
      LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
      LLM_KV_ATTENTION_SLIDING_WINDOW,
      LLM_KV_ATTENTION_SCALE,
@@ -126,7 +94,7 @@ index 45e458bb..eac7055b 100644
  
      LLM_KV_ROPE_DIMENSION_COUNT,
      LLM_KV_ROPE_DIMENSION_SECTIONS,
-@@ -305,6 +307,7 @@ enum llm_tensor {
+@@ -311,6 +313,7 @@ enum llm_tensor {
      LLM_TENSOR_ENC_OUTPUT_NORM,
      LLM_TENSOR_CLS,
      LLM_TENSOR_CLS_OUT,
@@ -135,7 +103,7 @@ index 45e458bb..eac7055b 100644
      LLM_TENSOR_CONVNEXT_DW,
      LLM_TENSOR_CONVNEXT_NORM,
 diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
-index c4053469..450738da 100644
+index ea87b295..f3955de9 100644
 --- a/src/llama-hparams.cpp
 +++ b/src/llama-hparams.cpp
 @@ -69,3 +69,11 @@ uint32_t llama_hparams::n_embd_v_s() const {
@@ -152,10 +120,10 @@ index c4053469..450738da 100644
 +}
 \ No newline at end of file
 diff --git a/src/llama-hparams.h b/src/llama-hparams.h
-index a29f20ec..fd898e27 100644
+index 1fe45410..1bdcdfd5 100644
 --- a/src/llama-hparams.h
 +++ b/src/llama-hparams.h
-@@ -52,6 +52,8 @@ struct llama_hparams {
+@@ -50,6 +50,8 @@ struct llama_hparams {
      std::array n_head_kv_arr;
      std::array n_ff_arr;
  
@@ -164,7 +132,7 @@ index a29f20ec..fd898e27 100644
      uint32_t n_layer_dense_lead = 0;
      uint32_t n_lora_q           = 0;
      uint32_t n_lora_kv          = 0;
-@@ -134,6 +136,9 @@ struct llama_hparams {
+@@ -133,6 +135,9 @@ struct llama_hparams {
  
      // dimension of the recurrent state embeddings
      uint32_t n_embd_v_s() const;
@@ -175,23 +143,23 @@ index a29f20ec..fd898e27 100644
  
  static_assert(std::is_trivially_copyable::value, "llama_hparams must be trivially copyable");
 diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp
-index 7743b465..422524a8 100644
+index 05d58ad9..1252aca1 100644
 --- a/src/llama-model-loader.cpp
 +++ b/src/llama-model-loader.cpp
-@@ -364,6 +364,7 @@ namespace GGUFMeta {
+@@ -439,6 +439,7 @@ namespace GGUFMeta {
      // TODO: this is not very clever - figure out something better
      template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required);
      template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required);
 +    template bool llama_model_loader::get_key_or_arr(const std::string & key, std::array & result, uint32_t n, bool required);
  
- llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, bool check_tensors, const struct llama_model_kv_override * param_overrides_p) {
-     int trace = 0;
+ llama_model_loader::llama_model_loader(
+         const std::string & fname,
 diff --git a/src/llama-model.cpp b/src/llama-model.cpp
-index 00b80c52..306c557d 100644
+index 36a0a009..ad1315c6 100644
 --- a/src/llama-model.cpp
 +++ b/src/llama-model.cpp
-@@ -1091,6 +1091,21 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) {
-                     default: model.type = e_model::MODEL_UNKNOWN;
+@@ -1238,6 +1238,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {
+                     default: type = LLM_TYPE_UNKNOWN;
                 }
              } break;
 +        case LLM_ARCH_SOLAR:
@@ -200,52 +168,19 @@ index 00b80c52..306c557d 100644
 +                for (size_t i = 0; i < hparams.n_bskcn_arr.max_size(); ++i) {
 +                    auto & bskcn = hparams.n_bskcn_arr[i];
 +                    bskcn.fill(0);
-+                    auto kv = LLM_KV(model.arch);
++                    auto kv = LLM_KV(arch);
 +                    ml.get_key_or_arr(format((kv(LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION) + ".%d").c_str(), i), bskcn, hparams.n_layer, false);
 +                }
 +
 +                switch (hparams.n_layer) {
-+                    case 64: model.type = e_model::MODEL_22B; break;
-+                    default: model.type = e_model::MODEL_UNKNOWN;
++                    case 64: type = LLM_TYPE_22B; break;
++                    default: type = LLM_TYPE_UNKNOWN;
 +                }
 +            } break;
          case LLM_ARCH_WAVTOKENIZER_DEC:
              {
                  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,    hparams.f_norm_eps);
-@@ -2065,6 +2080,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
-         case LLM_ARCH_GRANITE:
-         case LLM_ARCH_GRANITE_MOE:
-         case LLM_ARCH_CHAMELEON:
-+        case LLM_ARCH_SOLAR:
-             return LLAMA_ROPE_TYPE_NORM;
- 
-         // the pairs of head values are offset by n_rot/2
-diff --git a/src/llama-model.h b/src/llama-model.h
-index ce038932..c1b9c0a1 100644
---- a/src/llama-model.h
-+++ b/src/llama-model.h
-@@ -54,6 +54,7 @@ enum llm_type {
-     MODEL_15B,
-     MODEL_16B,
-     MODEL_20B,
-+    MODEL_22B,
-     MODEL_30B,
-     MODEL_32B,
-     MODEL_34B,
-@@ -275,6 +276,8 @@ struct llama_layer {
-     struct ggml_tensor * ffn_up_scale   = nullptr;
-     struct ggml_tensor * ffn_down_scale = nullptr;
- 
-+    struct ggml_tensor * bskcn_tv = nullptr;
-+
-     struct llama_layer_posnet posnet;
- 
-     struct llama_layer_convnext convnext;
-diff --git a/src/llama.cpp b/src/llama.cpp
-index 4eb3f6b9..7dec50ae 100644
---- a/src/llama.cpp
-+++ b/src/llama.cpp
-@@ -2206,6 +2206,35 @@ static bool llm_load_tensors(
+@@ -3316,6 +3331,34 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
  
                          layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
  
@@ -256,16 +191,16 @@ index 4eb3f6b9..7dec50ae 100644
 +                } break;
 +            case LLM_ARCH_SOLAR:
 +                {
-+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
++                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 +
 +                    // output
 +                    {
-+                        model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-+                        model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
++                        output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
++                        output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
 +                    }
 +
 +                    for (int i = 0; i < n_layer; ++i) {
-+                        auto & layer = model.layers[i];
++                        auto & layer = layers[i];
 +
 +                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
 +
@@ -277,16 +212,53 @@ index 4eb3f6b9..7dec50ae 100644
 +                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
 +
 +                        layer.bskcn_tv = create_tensor(tn(LLM_TENSOR_BSKCN_TV, "weight", i), {2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-+
                          layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
                          layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
                          layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-@@ -10226,6 +10255,158 @@ struct llm_build_context {
-         return gf;
-     }
+@@ -3900,6 +3943,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
+         case LLM_ARCH_GRANITE:
+         case LLM_ARCH_GRANITE_MOE:
+         case LLM_ARCH_CHAMELEON:
++        case LLM_ARCH_SOLAR:
+             return LLAMA_ROPE_TYPE_NORM;
  
-+    ggml_cgraph * build_solar() {
-+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+         // the pairs of head values are offset by n_rot/2
+diff --git a/src/llama-model.h b/src/llama-model.h
+index a7c30444..1afb0024 100644
+--- a/src/llama-model.h
++++ b/src/llama-model.h
+@@ -55,6 +55,7 @@ enum llm_type {
+     LLM_TYPE_15B,
+     LLM_TYPE_16B,
+     LLM_TYPE_20B,
++    LLM_TYPE_22B,
+     LLM_TYPE_30B,
+     LLM_TYPE_32B,
+     LLM_TYPE_34B,
+@@ -281,6 +282,8 @@ struct llama_layer {
+     struct ggml_tensor * ffn_up_scale   = nullptr;
+     struct ggml_tensor * ffn_down_scale = nullptr;
+ 
++    struct ggml_tensor * bskcn_tv = nullptr;
++
+     struct llama_layer_posnet posnet;
+ 
+     struct llama_layer_convnext convnext;
+diff --git a/src/llama.cpp b/src/llama.cpp
+index ac85bfed..6d320ea4 100644
+--- a/src/llama.cpp
++++ b/src/llama.cpp
+@@ -7953,9 +7953,155 @@ struct llm_build_context {
+         cb(img_logits, "img_logits", -1);
+         cur = ggml_set_1d(ctx0, cur, img_logits, ggml_element_size(cur) * img_token_start_idx);
+         cb(cur, "result_output", -1);
+-
+         ggml_build_forward_expand(gf, cur);
++        return gf;
++   }
++
++   ggml_cgraph * build_solar() {
++        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 +
 +        // mutable variable, needed during the last layer of the computation to skip unused tokens
 +        int32_t n_tokens = this->n_tokens;
@@ -333,7 +305,7 @@ index 4eb3f6b9..7dec50ae 100644
 +                   ggml_mul(ctx0, bskcn_2, ggml_view_1d(ctx0, model.layers[il].bskcn_tv, 1, 0)),
 +                   ggml_mul(ctx0, inpSA, ggml_view_1d(ctx0, model.layers[il].bskcn_tv, 1, ggml_element_size(model.layers[il].bskcn_tv))));
 +            }
-+
+ 
 +            // norm
 +            cur = llm_build_norm(ctx0, inpL, hparams,
 +                    model.layers[il].attn_norm, NULL,
@@ -422,25 +394,18 @@ index 4eb3f6b9..7dec50ae 100644
 +        }
 +
 +        cur = inpL;
-+
 +        cur = llm_build_norm(ctx0, cur, hparams,
 +                model.output_norm, NULL,
 +                LLM_NORM_RMS, cb, -1);
 +        cb(cur, "result_norm", -1);
-+
 +        // lm_head
 +        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
 +        cb(cur, "result_output", -1);
-+
 +        ggml_build_forward_expand(gf, cur);
-+
-+        return gf;
-+    }
-+
-     struct ggml_cgraph * build_wavtokenizer_dec() {
-         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+         return gf;
+     }
  
-@@ -10660,6 +10841,10 @@ static struct ggml_cgraph * llama_build_graph(
+@@ -8398,6 +8544,10 @@ static struct ggml_cgraph * llama_build_graph(
              {
                  result = llm.build_chameleon();
              } break;
diff --git a/llama/patches/0006-conditional-fattn.patch b/llama/patches/0006-conditional-fattn.patch
index 73990578..63af1f5c 100644
--- a/llama/patches/0006-conditional-fattn.patch
+++ b/llama/patches/0006-conditional-fattn.patch
@@ -8,10 +8,10 @@ Subject: [PATCH] conditional-fattn
  1 file changed, 2 insertions(+)
 
 diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
-index be29e979..aaa79ea4 100644
+index b094929b..36165840 100644
 --- a/ggml/src/ggml-cuda/ggml-cuda.cu
 +++ b/ggml/src/ggml-cuda/ggml-cuda.cu
-@@ -2159,9 +2159,11 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
+@@ -2282,9 +2282,11 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
          case GGML_OP_ARGSORT:
              ggml_cuda_op_argsort(ctx, dst);
              break;
diff --git a/llama/patches/0007-add-mllama-support.patch b/llama/patches/0007-add-mllama-support.patch
index 678dabad..efed0923 100644
--- a/llama/patches/0007-add-mllama-support.patch
+++ b/llama/patches/0007-add-mllama-support.patch
@@ -15,27 +15,27 @@ remaining is to implement the cross attention mask
  examples/llava/llava.cpp      |   5 +-
  ggml/src/ggml-backend-reg.cpp |   6 +-
  include/llama.h               |   6 +
- src/llama-arch.cpp            |  44 +++++
+ src/llama-arch.cpp            |  44 ++++++
  src/llama-arch.h              |  10 ++
  src/llama-batch.cpp           |   3 +
- src/llama-context.cpp         |  19 ++-
+ src/llama-context.cpp         |  28 ++--
  src/llama-context.h           |   2 +
  src/llama-cparams.h           |   1 +
- src/llama-hparams.cpp         |   8 +-
- src/llama-hparams.h           |   4 +
- src/llama-kv-cache.cpp        |  33 ++++
+ src/llama-hparams.cpp         |   6 +
+ src/llama-hparams.h           |   5 +
+ src/llama-kv-cache.cpp        |  13 +-
  src/llama-model-loader.cpp    |   2 +
- src/llama-model.cpp           |  59 ++-----
- src/llama-model.h             |  51 ++++++
+ src/llama-model.cpp           |  65 ++++++++-
+ src/llama-model.h             |  12 ++
  src/llama-quant.cpp           |   4 +-
- src/llama.cpp                 | 307 +++++++++++++++++++++++++++++++++-
- 17 files changed, 508 insertions(+), 56 deletions(-)
+ src/llama.cpp                 | 262 +++++++++++++++++++++++++++++++++-
+ 17 files changed, 452 insertions(+), 22 deletions(-)
 
 diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp
-index 16f30c56..0f0f3f62 100644
+index 518aad3f..f0e484a1 100644
 --- a/examples/llava/llava.cpp
 +++ b/examples/llava/llava.cpp
-@@ -429,7 +429,7 @@ struct llava_embd_batch {
+@@ -445,7 +445,7 @@ struct llava_embd_batch {
      std::vector seq_ids;
      std::vector         logits;
      llama_batch batch;
@@ -44,7 +44,7 @@ index 16f30c56..0f0f3f62 100644
          pos     .resize(n_tokens);
          n_seq_id.resize(n_tokens);
          seq_ids .resize(n_tokens + 1);
-@@ -441,6 +441,7 @@ struct llava_embd_batch {
+@@ -457,6 +457,7 @@ struct llava_embd_batch {
              /*n_tokens       =*/ n_tokens,
              /*tokens         =*/ nullptr,
              /*embd           =*/ embd,
@@ -52,7 +52,7 @@ index 16f30c56..0f0f3f62 100644
              /*pos            =*/ pos.data(),
              /*n_seq_id       =*/ n_seq_id.data(),
              /*seq_id         =*/ seq_ids.data(),
-@@ -464,7 +465,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
+@@ -480,7 +481,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
              n_eval = n_batch;
          }
          float * embd = image_embed->embed+i*n_embd;
@@ -62,7 +62,7 @@ index 16f30c56..0f0f3f62 100644
              LOG_ERR("%s : failed to eval\n", __func__);
              return false;
 diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
-index 7ddd178b..899d16f2 100644
+index 955ed505..95036ef8 100644
 --- a/ggml/src/ggml-backend-reg.cpp
 +++ b/ggml/src/ggml-backend-reg.cpp
 @@ -171,9 +171,9 @@ struct ggml_backend_registry {
@@ -79,10 +79,10 @@ index 7ddd178b..899d16f2 100644
          register_backend(ggml_backend_rpc_reg());
  #endif
 diff --git a/include/llama.h b/include/llama.h
-index a0d5ba5d..9f411960 100644
+index 47919602..cc948005 100644
 --- a/include/llama.h
 +++ b/include/llama.h
-@@ -250,6 +250,7 @@ extern "C" {
+@@ -249,6 +249,7 @@ extern "C" {
  
          llama_token  *  token;
          float        *  embd;
@@ -90,7 +90,7 @@ index a0d5ba5d..9f411960 100644
          llama_pos    *  pos;
          int32_t      *  n_seq_id;
          llama_seq_id ** seq_id;
-@@ -347,6 +348,7 @@ extern "C" {
+@@ -343,6 +344,7 @@ extern "C" {
          bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
          bool flash_attn;  // whether to use flash attention [EXPERIMENTAL]
          bool no_perf;     // whether to measure performance timings
@@ -98,9 +98,9 @@ index a0d5ba5d..9f411960 100644
  
          // Abort callback
          // if it returns true, execution of llama_decode() will be aborted
-@@ -426,6 +428,10 @@ extern "C" {
-                      struct llama_model * model,
-             struct llama_context_params   params);
+@@ -443,6 +445,10 @@ extern "C" {
+             struct llama_context_params   params),
+             "use llama_init_from_model instead");
  
 +    // TODO (jmorganca): this should most likely be passed in as part of a batch
 +    // and not set on the context for all batches.
@@ -110,7 +110,7 @@ index a0d5ba5d..9f411960 100644
      LLAMA_API void llama_free(struct llama_context * ctx);
  
 diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
-index 5b376c5e..b35aeb31 100644
+index a1e0ebcc..b6f20286 100644
 --- a/src/llama-arch.cpp
 +++ b/src/llama-arch.cpp
 @@ -6,6 +6,7 @@
@@ -121,15 +121,15 @@ index 5b376c5e..b35aeb31 100644
      { LLM_ARCH_DECI,             "deci"             },
      { LLM_ARCH_FALCON,           "falcon"           },
      { LLM_ARCH_GROK,             "grok"             },
-@@ -124,6 +125,7 @@ static const std::map LLM_KV_NAMES = {
-     { LLM_KV_ATTENTION_SLIDING_WINDOW,           "%s.attention.sliding_window"           },
-     { LLM_KV_ATTENTION_SCALE,                    "%s.attention.scale"                    },
-     { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,    "%s.attention.block_skip_connection"    },
-+    { LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS,   "%s.attention.cross_attention_layers"   },
+@@ -127,6 +128,7 @@ static const std::map LLM_KV_NAMES = {
+     { LLM_KV_ATTENTION_SLIDING_WINDOW,         "%s.attention.sliding_window"         },
+     { LLM_KV_ATTENTION_SCALE,                  "%s.attention.scale"                  },
+     { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,  "%s.attention.block_skip_connection"  },
++    { LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, "%s.attention.cross_attention_layers" },
  
      { LLM_KV_ROPE_DIMENSION_COUNT,      "%s.rope.dimension_count"                 },
      { LLM_KV_ROPE_DIMENSION_SECTIONS,   "%s.rope.dimension_sections"              },
-@@ -220,6 +222,40 @@ static const std::map> LLM_TENSOR_N
+@@ -225,6 +227,40 @@ static const std::map> LLM_TENSOR_N
              { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
          },
      },
@@ -170,7 +170,7 @@ index 5b376c5e..b35aeb31 100644
      {
          LLM_ARCH_DECI,
          {
-@@ -1393,6 +1429,14 @@ static const std::map LLM_TENSOR_INFOS = {
+@@ -1450,6 +1486,14 @@ static const std::map LLM_TENSOR_INFOS = {
      // this tensor is loaded for T5, but never used
      {LLM_TENSOR_DEC_CROSS_ATTN_REL_B,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
      {LLM_TENSOR_BSKCN_TV,                   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
@@ -186,7 +186,7 @@ index 5b376c5e..b35aeb31 100644
      {LLM_TENSOR_POS_NET_NORM,               {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
      {LLM_TENSOR_POS_NET_NORM1,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
 diff --git a/src/llama-arch.h b/src/llama-arch.h
-index eac7055b..e8235ae0 100644
+index 77919578..ec742224 100644
 --- a/src/llama-arch.h
 +++ b/src/llama-arch.h
 @@ -10,6 +10,7 @@
@@ -197,7 +197,7 @@ index eac7055b..e8235ae0 100644
      LLM_ARCH_DECI,
      LLM_ARCH_FALCON,
      LLM_ARCH_BAICHUAN,
-@@ -128,6 +129,7 @@ enum llm_kv {
+@@ -131,6 +132,7 @@ enum llm_kv {
      LLM_KV_ATTENTION_SLIDING_WINDOW,
      LLM_KV_ATTENTION_SCALE,
      LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,
@@ -205,7 +205,7 @@ index eac7055b..e8235ae0 100644
  
      LLM_KV_ROPE_DIMENSION_COUNT,
      LLM_KV_ROPE_DIMENSION_SECTIONS,
-@@ -308,6 +310,14 @@ enum llm_tensor {
+@@ -314,6 +316,14 @@ enum llm_tensor {
      LLM_TENSOR_CLS,
      LLM_TENSOR_CLS_OUT,
      LLM_TENSOR_BSKCN_TV,
@@ -249,10 +249,10 @@ index 01d5ca57..8682b0e6 100644
          batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
      }
 diff --git a/src/llama-context.cpp b/src/llama-context.cpp
-index b9c4a5bf..9d0e7ca3 100644
+index 47e79ed4..7b22fe13 100644
 --- a/src/llama-context.cpp
 +++ b/src/llama-context.cpp
-@@ -71,10 +71,19 @@ void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) {
+@@ -74,10 +74,19 @@ void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) {
      }
  
      if (ubatch.embd) {
@@ -275,7 +275,30 @@ index b9c4a5bf..9d0e7ca3 100644
      }
  
      if (ubatch.pos && lctx.inp_pos) {
-@@ -653,6 +662,10 @@ void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
+@@ -470,12 +479,11 @@ void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) {
+ size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
+     const auto & cparams = lctx.cparams;
+     const auto & hparams = lctx.model.hparams;
+-    const auto & vocab   = lctx.model.vocab;
+ 
+     const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max);
+ 
+     const auto n_batch = cparams.n_batch;
+-    const auto n_vocab = vocab.n_tokens();
++    const auto n_vocab = hparams.n_vocab;
+     const auto n_embd  = hparams.n_embd;
+ 
+     // TODO: use a per-batch flag for logits presence instead
+@@ -542,7 +550,7 @@ size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
+ void llama_output_reorder(struct llama_context & ctx) {
+     std::vector & out_ids = ctx.sbatch.out_ids;
+     if (!out_ids.empty()) {
+-        const uint32_t n_vocab = ctx.model.vocab.n_tokens();
++        const uint32_t n_vocab = ctx.model.hparams.n_vocab;
+         const uint32_t n_embd  = ctx.model.hparams.n_embd;
+ 
+         const int32_t n_outputs = ctx.n_outputs;
+@@ -657,6 +665,10 @@ void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
      ctx->cparams.causal_attn = causal_attn;
  }
  
@@ -286,8 +309,26 @@ index b9c4a5bf..9d0e7ca3 100644
  void llama_synchronize(struct llama_context * ctx) {
      ggml_backend_sched_synchronize(ctx->sched.get());
  
+@@ -726,7 +738,7 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
+             throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
+         }
+ 
+-        return ctx->logits + j*ctx->model.vocab.n_tokens();
++        return ctx->logits + j*ctx->model.hparams.n_vocab;
+     } catch (const std::exception & err) {
+         LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
+ #ifndef NDEBUG
+@@ -886,7 +898,7 @@ struct llama_data_write {
+     }
+ 
+     void write_logits(const struct llama_context * ctx) {
+-        const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.vocab.n_tokens());
++        const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_vocab);
+ 
+         write(&logits_size, sizeof(logits_size));
+ 
 diff --git a/src/llama-context.h b/src/llama-context.h
-index 0d163c47..4980a60e 100644
+index a9268b29..cf12c9d7 100644
 --- a/src/llama-context.h
 +++ b/src/llama-context.h
 @@ -107,6 +107,8 @@ struct llama_context {
@@ -312,7 +353,7 @@ index 252012f3..9681e5a0 100644
      enum llama_pooling_type pooling_type;
  
 diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
-index 450738da..42f8a58f 100644
+index f3955de9..0b841028 100644
 --- a/src/llama-hparams.cpp
 +++ b/src/llama-hparams.cpp
 @@ -2,6 +2,8 @@
@@ -328,18 +369,25 @@ index 450738da..42f8a58f 100644
      }
  
      GGML_ABORT("fatal error");
--}
-\ No newline at end of file
 +}
 +
 +bool llama_hparams::cross_attention_layers(uint32_t il) const {
 +    return std::find(cross_attn_layers.begin(), cross_attn_layers.end(), il) != cross_attn_layers.end();
-+}
+ }
+\ No newline at end of file
 diff --git a/src/llama-hparams.h b/src/llama-hparams.h
-index fd898e27..f826cd9a 100644
+index 1bdcdfd5..05383046 100644
 --- a/src/llama-hparams.h
 +++ b/src/llama-hparams.h
-@@ -53,6 +53,7 @@ struct llama_hparams {
+@@ -41,6 +41,7 @@ struct llama_hparams {
+     uint32_t n_expert = 0;
+     uint32_t n_expert_used = 0;
+     uint32_t n_rel_attn_bkts = 0;
++    uint32_t n_vocab = 0;
+ 
+     // for WavTokenizer
+     struct llama_hparams_posnet   posnet;
+@@ -51,6 +52,7 @@ struct llama_hparams {
      std::array n_ff_arr;
  
      std::array, 4> n_bskcn_arr = {};
@@ -347,65 +395,45 @@ index fd898e27..f826cd9a 100644
  
      uint32_t n_layer_dense_lead = 0;
      uint32_t n_lora_q           = 0;
-@@ -139,6 +140,9 @@ struct llama_hparams {
+@@ -138,6 +140,9 @@ struct llama_hparams {
  
      // Block skip connection
      bool n_bskcn(uint32_t n, uint32_t il) const;
 +
-+    // cross attention layers   
++    // cross attention layers
 +    bool cross_attention_layers(uint32_t il) const;
  };
  
  static_assert(std::is_trivially_copyable::value, "llama_hparams must be trivially copyable");
 diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
-index 53379253..cf814dbe 100644
+index feffdf0d..b541c5a3 100644
 --- a/src/llama-kv-cache.cpp
 +++ b/src/llama-kv-cache.cpp
-@@ -72,6 +72,39 @@ bool llama_kv_cache_init(
-     cache.v_l.reserve(n_layer);
+@@ -91,8 +91,17 @@ bool llama_kv_cache_init(
+             return false;
+         }
  
-     for (int i = 0; i < n_layer; i++) {
+-        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
+-        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
++        ggml_tensor * k, *v;
++
 +        // for cross attention layers
 +        if (model.arch == LLM_ARCH_MLLAMA && hparams.cross_attention_layers(i)) {
-+            const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
-+            const llama_model::buft_list_t * buft_list;
-+            if (offload) {
-+                buft_list = model.dev_layer.at(i).buft_list;
-+            } else {
-+                buft_list = &model.cpu_buft_list;
-+            }
-+            ggml_backend_buffer_type_t buft = select_buft(*buft_list,
-+                [&](ggml_context * ctx) {
-+                    ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
-+                    if (hparams.rope_type == LLAMA_ROPE_TYPE_NONE) {
-+                        return k;
-+                    }
-+                    ggml_tensor * p = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
-+                    return ggml_rope(ctx, k, p, hparams.n_rot, hparams.rope_type);
-+                });
-+            ggml_context * ctx = ctx_for_buft(buft);
-+
-+            if (!ctx) {
-+                LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__);
-+                return false;
-+            }
-+            ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i));
-+            ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i));
-+            ggml_format_name(k, "cache_k_l%d", i);
-+            ggml_format_name(v, "cache_v_l%d", i);
-+            cache.k_l.push_back(k);
-+            cache.v_l.push_back(v);
-+            continue;
++            k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_k, 6404, hparams.n_head_kv(i));
++            v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hparams.n_embd_head_v, 6404, hparams.n_head_kv(i));
++        } else {
++            k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
++            v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
 +        }
 +
-         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
-         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
- 
+         ggml_format_name(k, "cache_k_l%d", i);
+         ggml_format_name(v, "cache_v_l%d", i);
+         cache.k_l.push_back(k);
 diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp
-index 422524a8..b12d6566 100644
+index 1252aca1..45d08721 100644
 --- a/src/llama-model-loader.cpp
 +++ b/src/llama-model-loader.cpp
-@@ -240,6 +240,8 @@ namespace GGUFMeta {
+@@ -315,6 +315,8 @@ namespace GGUFMeta {
          return true;
      }
  
@@ -415,80 +443,47 @@ index 422524a8..b12d6566 100644
      bool llama_model_loader::get_arr(const std::string & key, std::array & result, bool required) {
          const int kid = gguf_find_key(meta.get(), key.c_str());
 diff --git a/src/llama-model.cpp b/src/llama-model.cpp
-index 306c557d..4f9bbf90 100644
+index ad1315c6..21819080 100644
 --- a/src/llama-model.cpp
 +++ b/src/llama-model.cpp
-@@ -146,46 +146,6 @@ std::string llama_model_ftype_name(const llama_model & model) {
-     return llama_model_ftype_name(model.ftype);
- }
+@@ -401,6 +401,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
  
--template
--static bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, F & fn) {
--    ggml_init_params params = {
--        /*.mem_size   =*/ ggml_tensor_overhead()*8,
--        /*.mem_buffer =*/ NULL,
--        /*.no_alloc   =*/ true,
--    };
--
--    ggml_context_ptr ctx { ggml_init(params) };
--    if (!ctx) {
--        throw std::runtime_error(format("failed to create ggml context"));
--    }
--
--    ggml_backend_buffer_ptr buf { ggml_backend_buft_alloc_buffer(buft, 0) };
--    ggml_tensor * op_tensor = fn(ctx.get());
--    for (int i = 0; i < GGML_MAX_SRC; i++) {
--        if (op_tensor->src[i] != nullptr) {
--            assert(op_tensor->src[i]->buffer == nullptr);
--            op_tensor->src[i]->buffer = buf.get();
--        }
--    }
--
--    bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
--
--    return op_supported;
--}
--
--template
--static ggml_backend_buffer_type_t select_buft(const llama_model::buft_list_t & buft_list, const F & fn) {
--    for (const auto & cur : buft_list) {
--        ggml_backend_dev_t cur_dev = cur.first;
--        ggml_backend_buffer_type_t cur_buft = cur.second;
--        if (buft_supported(cur_buft, cur_dev, fn)) {
--            return cur_buft;
--        }
--    }
--
--    throw std::runtime_error(format("no suitable buffer type found"));
--}
--
- ggml_backend_buffer_type_t llama_model_select_buft(const llama_model & model, int il) {
-     return select_buft(
-             *model.dev_layer.at(il).buft_list,
-@@ -312,9 +272,11 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) {
+     // get general kv
+     ml.get_key(LLM_KV_GENERAL_NAME, name, false);
++    ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab, false);
+ 
+     // everything past this point is not vocab-related
+     if (hparams.vocab_only) {
+@@ -412,6 +413,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
+     ml.get_key(LLM_KV_BLOCK_COUNT,       hparams.n_layer);
+     ml.get_key(LLM_KV_EXPERT_COUNT,      hparams.n_expert,      false);
+     ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false);
++    ml.get_key(LLM_KV_VOCAB_SIZE,        hparams.n_vocab,       false);
+ 
+     if (arch == LLM_ARCH_WAVTOKENIZER_DEC) {
+         ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd_features);
+@@ -435,9 +437,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
      std::fill(hparams.n_head_arr.begin(),    hparams.n_head_arr.end(),    0);
      std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
      std::fill(hparams.n_ff_arr.begin(),      hparams.n_ff_arr.end(),      0);
 +    std::fill(hparams.cross_attn_layers.begin(), hparams.cross_attn_layers.end(), -1);
  
--    ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH,  hparams.n_ff_arr,   hparams.n_layer, false);
--    ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false);
-+    ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH,       hparams.n_ff_arr,   hparams.n_layer, false);
-+    ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT,      hparams.n_head_arr, hparams.n_layer, false);
+     ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH,  hparams.n_ff_arr,   hparams.n_layer, false);
+     ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false);
 +    ml.get_arr(LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, hparams.cross_attn_layers, false);
  
      // n_head_kv is optional, default to n_head
      hparams.n_head_kv_arr = hparams.n_head_arr;
-@@ -363,7 +325,7 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) {
+@@ -486,7 +490,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
  
          ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
  
--        if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_DECI || model.arch == LLM_ARCH_FALCON) {
-+        if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_MLLAMA || model.arch == LLM_ARCH_DECI || model.arch == LLM_ARCH_FALCON) {
+-        if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON) {
++        if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_MLLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON) {
              if (hparams.n_rot != hparams.n_embd_head_k) {
                  throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
              }
-@@ -405,6 +367,16 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) {
+@@ -530,6 +534,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                      }
                  }
              } break;
@@ -497,145 +492,44 @@ index 306c557d..4f9bbf90 100644
 +                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
 +
 +                switch (hparams.n_layer) {
-+                    case 40: model.type = e_model::MODEL_11B; break;
-+                    case 100: model.type = e_model::MODEL_90B; break;
-+                    default: model.type = e_model::MODEL_UNKNOWN;
++                    case 40: type = LLM_TYPE_11B; break;
++                    case 100: type = LLM_TYPE_90B; break;
++                    default: type = LLM_TYPE_UNKNOWN;
 +                }
 +            } break;
          case LLM_ARCH_DECI:
              {
                  ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-@@ -2062,6 +2034,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
- 
-         // use what we call a normal RoPE, operating on pairs of consecutive head values
-         case LLM_ARCH_LLAMA:
-+        case LLM_ARCH_MLLAMA:
-         case LLM_ARCH_DECI:
-         case LLM_ARCH_BAICHUAN:
-         case LLM_ARCH_STARCODER:
-diff --git a/src/llama-model.h b/src/llama-model.h
-index c1b9c0a1..5b23e2ba 100644
---- a/src/llama-model.h
-+++ b/src/llama-model.h
-@@ -9,6 +9,7 @@
- #include "ggml-cpp.h"
- 
- #include 
-+#include 
- 
- // available models
- // TODO: this enum does not follow the enum naming convention
-@@ -62,6 +63,7 @@ enum llm_type {
-     MODEL_40B,
-     MODEL_65B,
-     MODEL_70B,
-+    MODEL_90B,
-     MODEL_236B,
-     MODEL_314B,
-     MODEL_671B,
-@@ -278,6 +280,16 @@ struct llama_layer {
- 
-     struct ggml_tensor * bskcn_tv = nullptr;
- 
-+     // cross attention
-+    struct ggml_tensor * cross_attn_k_norm = nullptr;
-+    struct ggml_tensor * cross_attn_k_proj = nullptr;
-+    struct ggml_tensor * cross_attn_o_proj = nullptr;
-+    struct ggml_tensor * cross_attn_q_norm = nullptr;
-+    struct ggml_tensor * cross_attn_q_proj = nullptr;
-+    struct ggml_tensor * cross_attn_v_proj = nullptr;
-+    struct ggml_tensor * cross_attn_attn_gate = nullptr;
-+    struct ggml_tensor * cross_attn_mlp_gate = nullptr;
-+
-     struct llama_layer_posnet posnet;
- 
-     struct llama_layer_convnext convnext;
-@@ -376,6 +388,45 @@ std::string llama_model_arch_name (const llama_model & model);
- std::string llama_model_type_name (const llama_model & model);
- std::string llama_model_ftype_name(const llama_model & model);
- 
-+template
-+bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, F & fn) {
-+    ggml_init_params params = {
-+        /*.mem_size   =*/ ggml_tensor_overhead()*8,
-+        /*.mem_buffer =*/ NULL,
-+        /*.no_alloc   =*/ true,
-+    };
-+
-+    ggml_context_ptr ctx { ggml_init(params) };
-+    if (!ctx) {
-+        throw std::runtime_error("failed to create ggml context");
-+    }
-+
-+    ggml_backend_buffer_ptr buf { ggml_backend_buft_alloc_buffer(buft, 0) };
-+    ggml_tensor * op_tensor = fn(ctx.get());
-+    for (int i = 0; i < GGML_MAX_SRC; i++) {
-+        if (op_tensor->src[i] != nullptr) {
-+            op_tensor->src[i]->buffer = buf.get();
-+        }
-+    }
-+
-+    bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
-+
-+    return op_supported;
-+}
-+
-+template
-+ggml_backend_buffer_type_t select_buft(const llama_model::buft_list_t & buft_list, const F & fn) {
-+    for (const auto & cur : buft_list) {
-+        ggml_backend_dev_t cur_dev = cur.first;
-+        ggml_backend_buffer_type_t cur_buft = cur.second;
-+        if (buft_supported(cur_buft, cur_dev, fn)) {
-+            return cur_buft;
-+        }
-+    }
-+
-+    throw std::runtime_error("no suitable buffer type found");
-+}
-+
- // used by llama_adapter_cvec
- ggml_backend_buffer_type_t llama_model_select_buft(const llama_model & model, int il);
- 
-diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
-index 42974f8f..27def6fd 100644
---- a/src/llama-quant.cpp
-+++ b/src/llama-quant.cpp
-@@ -629,7 +629,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
-         if (llama_model_has_encoder(&model)) {
-             n_attn_layer *= 3;
-         }
--        GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected");
-+        if (qs.n_attention_wv != n_attn_layer) {
-+            LLAMA_LOG_WARN("%s: n_attention_wv is unexpected, expected: %d, found: %d\n", __func__, n_attn_layer, qs.n_attention_wv);
-+        }
-     }
- 
-     size_t total_size_org = 0;
-diff --git a/src/llama.cpp b/src/llama.cpp
-index 7dec50ae..bac66c24 100644
---- a/src/llama.cpp
-+++ b/src/llama.cpp
-@@ -563,6 +563,52 @@ static bool llm_load_tensors(
+@@ -1398,7 +1412,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
+         const int64_t n_embd_head_v = hparams.n_embd_head_v;
+         const int64_t n_ff          = hparams.n_ff();
+         const int64_t n_embd_gqa    = n_embd_v_gqa;
+-        const int64_t n_vocab       = vocab.n_tokens();
++        const int64_t n_vocab       = hparams.n_vocab;
+         const int64_t n_token_types = vocab.n_token_types();
+         const int64_t n_rot         = hparams.n_rot;
+         const int64_t n_expert      = hparams.n_expert;
+@@ -1581,6 +1595,52 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                          }
                      }
                  } break;
 +            case LLM_ARCH_MLLAMA:
 +                {
-+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab+8}, 0);
++                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab+8}, 0);
 +
 +                    // output
 +                    {
-+                        model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-+                        model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
++                        output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
++                        output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
 +
 +                        // if output is NULL, init from the input tok embed
-+                        if (model.output == NULL) {
-+                            model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
++                        if (output == NULL) {
++                            output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
 +                        }
 +                    }
 +
 +                    for (int i = 0; i < n_layer; ++i) {
-+                        auto & layer = model.layers[i];
++                        auto & layer = layers[i];
 +
 +                        if (hparams.cross_attention_layers(i)) {
 +                            layer.cross_attn_k_norm = create_tensor(tn(LLM_TENSOR_CROSS_ATTN_K_NORM,   "weight", i), {128}, 0);
@@ -667,17 +561,72 @@ index 7dec50ae..bac66c24 100644
 +                } break;
              case LLM_ARCH_DECI:
                  {
-                     model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-@@ -2514,7 +2560,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
+                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+@@ -3925,6 +3985,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
  
-         if (model.vocab.type != LLAMA_VOCAB_TYPE_NONE &&
-             model.hparams.n_vocab != model.vocab.id_to_token.size()) {
--            throw std::runtime_error("vocab size mismatch");
-+            LLAMA_LOG_WARN("%s: vocab mismatch %u !- %zu ...\n", __func__, model.hparams.n_vocab, model.vocab.id_to_token.size());
+         // use what we call a normal RoPE, operating on pairs of consecutive head values
+         case LLM_ARCH_LLAMA:
++        case LLM_ARCH_MLLAMA:
+         case LLM_ARCH_DECI:
+         case LLM_ARCH_BAICHUAN:
+         case LLM_ARCH_STARCODER:
+diff --git a/src/llama-model.h b/src/llama-model.h
+index 1afb0024..7cf57587 100644
+--- a/src/llama-model.h
++++ b/src/llama-model.h
+@@ -9,6 +9,7 @@
+ #include 
+ #include 
+ #include 
++#include 
+ 
+ struct llama_model_loader;
+ 
+@@ -63,6 +64,7 @@ enum llm_type {
+     LLM_TYPE_40B,
+     LLM_TYPE_65B,
+     LLM_TYPE_70B,
++    LLM_TYPE_90B,
+     LLM_TYPE_236B,
+     LLM_TYPE_314B,
+     LLM_TYPE_671B,
+@@ -284,6 +286,16 @@ struct llama_layer {
+ 
+     struct ggml_tensor * bskcn_tv = nullptr;
+ 
++    // cross attention
++    struct ggml_tensor * cross_attn_k_norm = nullptr;
++    struct ggml_tensor * cross_attn_k_proj = nullptr;
++    struct ggml_tensor * cross_attn_o_proj = nullptr;
++    struct ggml_tensor * cross_attn_q_norm = nullptr;
++    struct ggml_tensor * cross_attn_q_proj = nullptr;
++    struct ggml_tensor * cross_attn_v_proj = nullptr;
++    struct ggml_tensor * cross_attn_attn_gate = nullptr;
++    struct ggml_tensor * cross_attn_mlp_gate = nullptr;
++
+     struct llama_layer_posnet posnet;
+ 
+     struct llama_layer_convnext convnext;
+diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
+index fb798265..6eb1da08 100644
+--- a/src/llama-quant.cpp
++++ b/src/llama-quant.cpp
+@@ -632,7 +632,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
+         if (llama_model_has_encoder(&model)) {
+             n_attn_layer *= 3;
          }
+-        GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected");
++        if (qs.n_attention_wv != n_attn_layer) {
++            LLAMA_LOG_WARN("%s: n_attention_wv is unexpected, expected: %d, found: %d\n", __func__, n_attn_layer, qs.n_attention_wv);
++        }
+     }
  
-         if (params.vocab_only) {
-@@ -2598,6 +2644,21 @@ static struct ggml_tensor * llm_build_inp_embd(
+     size_t total_size_org = 0;
+diff --git a/src/llama.cpp b/src/llama.cpp
+index 6d320ea4..8f7902df 100644
+--- a/src/llama.cpp
++++ b/src/llama.cpp
+@@ -154,6 +154,21 @@ static struct ggml_tensor * llm_build_inp_embd(
      return inpL;
  }
  
@@ -699,7 +648,7 @@ index 7dec50ae..bac66c24 100644
  static void llm_build_kv_store(
          struct ggml_context * ctx,
          const llama_hparams & hparams,
-@@ -3593,6 +3654,7 @@ struct llm_build_context {
+@@ -1157,6 +1172,7 @@ struct llm_build_context {
          lctx.inp_pos_bucket    = nullptr;
          lctx.inp_embd_enc      = nullptr;
          lctx.inp_KQ_mask_cross = nullptr;
@@ -707,12 +656,12 @@ index 7dec50ae..bac66c24 100644
      }
  
      void free() {
-@@ -4074,6 +4136,240 @@ struct llm_build_context {
+@@ -1639,6 +1655,240 @@ struct llm_build_context {
          return gf;
      }
  
-+        struct ggml_cgraph * build_mllama() {
-+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
++    struct ggml_cgraph * build_mllama() {
++        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 +
 +        // mutable variable, needed during the last layer of the computation to skip unused tokens
 +        int32_t n_tokens = this->n_tokens;
@@ -946,9 +895,9 @@ index 7dec50ae..bac66c24 100644
 +    }
 +
      struct ggml_cgraph * build_deci() {
-         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
  
-@@ -10646,6 +10942,10 @@ static struct ggml_cgraph * llama_build_graph(
+@@ -8344,6 +8594,10 @@ static struct ggml_cgraph * llama_build_graph(
              {
                  result = llm.build_llama();
              } break;
@@ -959,16 +908,33 @@ index 7dec50ae..bac66c24 100644
          case LLM_ARCH_DECI:
              {
                  result = llm.build_deci();
-@@ -10971,7 +11271,7 @@ static int llama_decode_internal(
+@@ -8634,7 +8888,7 @@ static int llama_prepare_sbatch(
          n_outputs = 1;
      }
  
 -    lctx.sbatch.from_batch(batch, n_embd,
 +    lctx.sbatch.from_batch(batch, batch.n_embd,
-         /* simple_split */ !kv_self.recurrent,
+         /* simple_split */ !lctx.kv_self.recurrent,
          /* logits_all   */ n_outputs == n_tokens_all);
  
-@@ -11282,7 +11582,7 @@ static int llama_encode_internal(
+@@ -8749,7 +9003,6 @@ static int llama_decode_impl(
+     const llama_batch & batch = batch_allocr.batch;
+ 
+     const auto & model   = lctx.model;
+-    const auto & vocab   = model.vocab;
+     const auto & hparams = model.hparams;
+     const auto & cparams = lctx.cparams;
+ 
+@@ -8760,7 +9013,7 @@ static int llama_decode_impl(
+     llama_kv_slot_restorer kv_slot_restorer(kv_self);
+ 
+     const int64_t n_embd  = hparams.n_embd;
+-    const int64_t n_vocab = vocab.n_tokens();
++    const int64_t n_vocab = hparams.n_vocab;
+ 
+     uint32_t n_outputs = 0;
+     uint32_t n_outputs_prev = 0;
+@@ -9025,7 +9278,7 @@ static int llama_encode_impl(
  
      const int64_t n_embd = hparams.n_embd;
  
@@ -977,7 +943,7 @@ index 7dec50ae..bac66c24 100644
  
      const llama_ubatch ubatch = lctx.sbatch.split_simple(n_tokens);
  
-@@ -11775,6 +12075,7 @@ struct llama_context_params llama_context_default_params() {
+@@ -9511,6 +9764,7 @@ struct llama_context_params llama_context_default_params() {
          /*.offload_kqv                 =*/ true,
          /*.flash_attn                  =*/ false,
          /*.no_perf                     =*/ true,
diff --git a/llama/patches/0008-add-unpad-operator.patch b/llama/patches/0008-add-unpad-operator.patch
index fd070df9..bfa82de2 100644
--- a/llama/patches/0008-add-unpad-operator.patch
+++ b/llama/patches/0008-add-unpad-operator.patch
@@ -15,10 +15,10 @@ Subject: [PATCH] add unpad operator
  8 files changed, 220 insertions(+), 2 deletions(-)
 
 diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
-index c714fc8c..1bc50fca 100644
+index dd0c6a96..8d269a9c 100644
 --- a/ggml/include/ggml.h
 +++ b/ggml/include/ggml.h
-@@ -499,6 +499,7 @@ extern "C" {
+@@ -487,6 +487,7 @@ extern "C" {
          GGML_OP_UPSCALE, // nearest interpolate
          GGML_OP_PAD,
          GGML_OP_PAD_REFLECT_1D,
@@ -26,7 +26,7 @@ index c714fc8c..1bc50fca 100644
          GGML_OP_ARANGE,
          GGML_OP_TIMESTEP_EMBEDDING,
          GGML_OP_ARGSORT,
-@@ -1735,6 +1736,15 @@ extern "C" {
+@@ -1743,6 +1744,15 @@ extern "C" {
              int                   p0,
              int                   p1);
  
@@ -43,10 +43,10 @@ index c714fc8c..1bc50fca 100644
      // timesteps: [N,]
      // return: [N, dim]
 diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
-index b7fefb9d..b307d554 100644
+index 72325349..2f606d82 100644
 --- a/ggml/src/ggml-cpu/ggml-cpu.c
 +++ b/ggml/src/ggml-cpu/ggml-cpu.c
-@@ -10588,6 +10588,59 @@ static void ggml_compute_forward_pad_reflect_1d(
+@@ -10844,6 +10844,59 @@ static void ggml_compute_forward_pad_reflect_1d(
      }
  }
  
@@ -106,7 +106,7 @@ index b7fefb9d..b307d554 100644
  // ggml_compute_forward_arange
  
  static void ggml_compute_forward_arange_f32(
-@@ -12690,6 +12743,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
+@@ -13137,6 +13190,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
              {
                  ggml_compute_forward_pad_reflect_1d(params, tensor);
              } break;
@@ -117,7 +117,7 @@ index b7fefb9d..b307d554 100644
          case GGML_OP_ARANGE:
              {
                  ggml_compute_forward_arange(params, tensor);
-@@ -13033,6 +13090,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
+@@ -13484,6 +13541,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
          case GGML_OP_UPSCALE:
          case GGML_OP_PAD:
          case GGML_OP_PAD_REFLECT_1D:
@@ -126,10 +126,10 @@ index b7fefb9d..b307d554 100644
          case GGML_OP_TIMESTEP_EMBEDDING:
          case GGML_OP_ARGSORT:
 diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
-index aaa79ea4..9286f866 100644
+index 36165840..1adf08fa 100644
 --- a/ggml/src/ggml-cuda/ggml-cuda.cu
 +++ b/ggml/src/ggml-cuda/ggml-cuda.cu
-@@ -2082,6 +2082,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
+@@ -2198,6 +2198,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
          case GGML_OP_PAD:
              ggml_cuda_op_pad(ctx, dst);
              break;
@@ -139,8 +139,8 @@ index aaa79ea4..9286f866 100644
          case GGML_OP_ARANGE:
              ggml_cuda_op_arange(ctx, dst);
              break;
-@@ -3010,6 +3013,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
-         case GGML_OP_GROUP_NORM:
+@@ -3197,6 +3200,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
+             return ggml_is_contiguous(op->src[0]);
          case GGML_OP_UPSCALE:
          case GGML_OP_PAD:
 +        case GGML_OP_UNPAD:
@@ -148,7 +148,7 @@ index aaa79ea4..9286f866 100644
          case GGML_OP_TIMESTEP_EMBEDDING:
          case GGML_OP_LEAKY_RELU:
 diff --git a/ggml/src/ggml-cuda/pad.cu b/ggml/src/ggml-cuda/pad.cu
-index aba539e8..39fd4b16 100644
+index aba539e8..b4b87409 100644
 --- a/ggml/src/ggml-cuda/pad.cu
 +++ b/ggml/src/ggml-cuda/pad.cu
 @@ -47,3 +47,49 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -201,6 +201,7 @@ index aba539e8..39fd4b16 100644
 +        src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
 +        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
 +}
+\ No newline at end of file
 diff --git a/ggml/src/ggml-cuda/pad.cuh b/ggml/src/ggml-cuda/pad.cuh
 index 8fd386b0..e2ededc3 100644
 --- a/ggml/src/ggml-cuda/pad.cuh
@@ -211,10 +212,10 @@ index 8fd386b0..e2ededc3 100644
  void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 +void ggml_cuda_op_unpad(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
-index cd8ef741..318addec 100644
+index fd9a4e77..e4c093f9 100644
 --- a/ggml/src/ggml-metal/ggml-metal.m
 +++ b/ggml/src/ggml-metal/ggml-metal.m
-@@ -311,6 +311,7 @@ enum ggml_metal_kernel_type {
+@@ -331,6 +331,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
      GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
      GGML_METAL_KERNEL_TYPE_PAD_F32,
      GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32,
@@ -222,7 +223,7 @@ index cd8ef741..318addec 100644
      GGML_METAL_KERNEL_TYPE_ARANGE_F32,
      GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
      GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
-@@ -910,6 +911,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
+@@ -946,6 +947,7 @@ @implementation GGMLMetalClass
          GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32,                   upscale_f32,                    true);
          GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32,                       pad_f32,                        true);
          GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32,            pad_reflect_1d_f32,             true);
@@ -230,7 +231,7 @@ index cd8ef741..318addec 100644
          GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,        timestep_embedding_f32,         true);
          GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32,                    arange_f32,                     true);
          GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,           argsort_f32_i32_asc,            true);
-@@ -1145,6 +1147,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
+@@ -1254,6 +1256,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
          case GGML_OP_UPSCALE:
          case GGML_OP_PAD:
          case GGML_OP_PAD_REFLECT_1D:
@@ -238,7 +239,7 @@ index cd8ef741..318addec 100644
          case GGML_OP_ARANGE:
          case GGML_OP_TIMESTEP_EMBEDDING:
          case GGML_OP_ARGSORT:
-@@ -3348,6 +3351,36 @@ static void ggml_metal_encode_node(
+@@ -3469,6 +3472,36 @@ static void ggml_metal_encode_node(
  
                  const int nth = MIN(1024, ne0);
  
@@ -276,10 +277,10 @@ index cd8ef741..318addec 100644
              } break;
          case GGML_OP_ARANGE:
 diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
-index 8ba43904..204c93e6 100644
+index d092a169..f38909d0 100644
 --- a/ggml/src/ggml-metal/ggml-metal.metal
 +++ b/ggml/src/ggml-metal/ggml-metal.metal
-@@ -2944,6 +2944,51 @@ kernel void kernel_pad_reflect_1d_f32(
+@@ -2953,6 +2953,51 @@ kernel void kernel_pad_reflect_1d_f32(
      }
  }
  
@@ -332,10 +333,10 @@ index 8ba43904..204c93e6 100644
      device        char * dst,
      constant   int64_t & ne0,
 diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
-index 2bbe5f48..7ffcd907 100644
+index 7fc06724..635aa299 100644
 --- a/ggml/src/ggml.c
 +++ b/ggml/src/ggml.c
-@@ -954,6 +954,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
+@@ -962,6 +962,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
      "UPSCALE",
      "PAD",
      "PAD_REFLECT_1D",
@@ -343,16 +344,16 @@ index 2bbe5f48..7ffcd907 100644
      "ARANGE",
      "TIMESTEP_EMBEDDING",
      "ARGSORT",
-@@ -987,7 +988,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
+@@ -996,7 +997,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
      "OPT_STEP_ADAMW",
  };
  
--static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
-+static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
+-static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
++static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
  
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
      "none",
-@@ -1050,6 +1051,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
+@@ -1059,6 +1060,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
      "upscale(x)",
      "pad(x)",
      "pad_reflect_1d(x)",
@@ -360,16 +361,16 @@ index 2bbe5f48..7ffcd907 100644
      "arange(start, stop, step)",
      "timestep_embedding(timesteps, dim, max_period)",
      "argsort(x)",
-@@ -1083,7 +1085,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
+@@ -1093,7 +1095,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
      "adamw(x)",
  };
  
--static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
-+static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
+-static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
++static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
  
  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
  
-@@ -4214,6 +4216,25 @@ struct ggml_tensor * ggml_pad_reflect_1d(
+@@ -4225,6 +4227,25 @@ struct ggml_tensor * ggml_pad_reflect_1d(
      return result;
  }
  
diff --git a/llama/patches/0009-fix-deepseek-deseret-regex.patch b/llama/patches/0009-fix-deepseek-deseret-regex.patch
index 5c334cfd..715c5206 100644
--- a/llama/patches/0009-fix-deepseek-deseret-regex.patch
+++ b/llama/patches/0009-fix-deepseek-deseret-regex.patch
@@ -11,10 +11,10 @@ the characters
  2 files changed, 23 insertions(+), 1 deletion(-)
 
 diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
-index 3fcfcaa3..8f44705a 100644
+index a4eee9b8..1ca827eb 100644
 --- a/src/llama-vocab.cpp
 +++ b/src/llama-vocab.cpp
-@@ -375,7 +375,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
+@@ -295,7 +295,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
              case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM:
                  regex_exprs = {
                      "[\r\n]",
@@ -24,7 +24,7 @@ index 3fcfcaa3..8f44705a 100644
                      "\\s+$",
                      "[一-龥ࠀ-一가-퟿]+",
 diff --git a/src/unicode.cpp b/src/unicode.cpp
-index 7aca6544..6155da80 100644
+index e63bb4ab..9dd53b9a 100644
 --- a/src/unicode.cpp
 +++ b/src/unicode.cpp
 @@ -2,6 +2,11 @@
@@ -39,7 +39,7 @@ index 7aca6544..6155da80 100644
  #include "unicode.h"
  #include "unicode-data.h"
  
-@@ -201,6 +206,22 @@ static std::unordered_map unicode_utf8_to_byte_map() {
+@@ -200,6 +205,22 @@ static std::unordered_map unicode_utf8_to_byte_map() {
  }
  
  static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
@@ -62,7 +62,7 @@ index 7aca6544..6155da80 100644
  #if defined(__clang__)
      // disable C++17 deprecation warning for std::codecvt_utf8
  #    pragma clang diagnostic push
-@@ -214,6 +235,7 @@ static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
+@@ -213,6 +234,7 @@ static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
  #endif
  
      return conv.from_bytes(s);
diff --git a/llama/patches/0010-Maintain-ordering-for-rules-for-grammar.patch b/llama/patches/0010-Maintain-ordering-for-rules-for-grammar.patch
index 33b504ec..1e930fb2 100644
--- a/llama/patches/0010-Maintain-ordering-for-rules-for-grammar.patch
+++ b/llama/patches/0010-Maintain-ordering-for-rules-for-grammar.patch
@@ -8,11 +8,11 @@ Subject: [PATCH] Maintain ordering for rules for grammar
  1 file changed, 1 insertion(+), 1 deletion(-)
 
 diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp
-index dadc18c8..2a8dbd22 100644
+index 3ebcc3d9..30c28808 100644
 --- a/common/json-schema-to-grammar.cpp
 +++ b/common/json-schema-to-grammar.cpp
-@@ -391,7 +391,7 @@ class SchemaConverter {
- private:
+@@ -346,7 +346,7 @@ private:
+     friend std::string build_grammar(const std::function & cb, const common_grammar_options & options);
      std::function _fetch_json;
      bool _dotall;
 -    std::map _rules;
diff --git a/llama/patches/0011-fix-missing-arg-in-static-assert-on-windows.patch b/llama/patches/0011-fix-missing-arg-in-static-assert-on-windows.patch
deleted file mode 100644
index 8c43ad3d..00000000
--- a/llama/patches/0011-fix-missing-arg-in-static-assert-on-windows.patch
+++ /dev/null
@@ -1,22 +0,0 @@
-From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
-From: jmorganca 
-Date: Sat, 14 Dec 2024 12:54:00 -0800
-Subject: [PATCH] fix missing arg in static assert on windows
-
----
- ggml/src/ggml-cuda/concat.cu | 2 +-
- 1 file changed, 1 insertion(+), 1 deletion(-)
-
-diff --git a/ggml/src/ggml-cuda/concat.cu b/ggml/src/ggml-cuda/concat.cu
-index 2f42b8a9..5eb9f08d 100644
---- a/ggml/src/ggml-cuda/concat.cu
-+++ b/ggml/src/ggml-cuda/concat.cu
-@@ -124,7 +124,7 @@ static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE)
-           uint64_t   nb1,
-           uint64_t   nb2,
-           uint64_t   nb3){
--    static_assert(dim >= 0 && dim <= 3);
-+    static_assert(dim >= 0 && dim <= 3, "dim must be between 0 and 3");
- 
-     const int64_t i3 = blockIdx.z;
-     const int64_t i2 = blockIdx.y;
diff --git a/llama/patches/0012-llama-Ensure-KV-cache-is-fully-defragmented.patch b/llama/patches/0011-llama-Ensure-KV-cache-is-fully-defragmented.patch
similarity index 84%
rename from llama/patches/0012-llama-Ensure-KV-cache-is-fully-defragmented.patch
rename to llama/patches/0011-llama-Ensure-KV-cache-is-fully-defragmented.patch
index 3ef51f4e..ff057539 100644
--- a/llama/patches/0012-llama-Ensure-KV-cache-is-fully-defragmented.patch
+++ b/llama/patches/0011-llama-Ensure-KV-cache-is-fully-defragmented.patch
@@ -19,10 +19,10 @@ multiple batches of processing until everything is complete.
  1 file changed, 46 insertions(+), 53 deletions(-)
 
 diff --git a/src/llama.cpp b/src/llama.cpp
-index bac66c24..c95da45d 100644
+index 8f7902df..01854fce 100644
 --- a/src/llama.cpp
 +++ b/src/llama.cpp
-@@ -3536,6 +3536,13 @@ static struct ggml_tensor * llm_build_rwkv6_channel_mix(
+@@ -1054,6 +1054,13 @@ static struct ggml_tensor * llm_build_rwkv6_channel_mix(
      return ggml_mul(ctx, r, llm_build_lora_mm(lctx, ctx, layer->channel_mix_value, k));
  }
  
@@ -36,13 +36,13 @@ index bac66c24..c95da45d 100644
  struct llm_build_context {
      const llama_model    & model;
            llama_context  & lctx;
-@@ -3712,35 +3719,23 @@ struct llm_build_context {
+@@ -1230,35 +1237,23 @@ struct llm_build_context {
          return gf;
      }
  
 -    struct ggml_cgraph * build_defrag(const std::vector & ids) {
 +    struct ggml_cgraph * build_defrag(const std::vector & moves) {
-         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
  
 -        for (uint32_t i = 0; i < ids.size(); ++i) {
 -            const uint32_t id = ids[i];
@@ -78,7 +78,7 @@ index bac66c24..c95da45d 100644
  
                  ggml_tensor * view_v_src;
                  ggml_tensor * view_v_dst;
-@@ -3748,31 +3743,29 @@ struct llm_build_context {
+@@ -1266,31 +1261,29 @@ struct llm_build_context {
                  if (flash_attn) {
                      // NOTE: the V cache is not transposed when using flash attention
                      view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
@@ -118,7 +118,7 @@ index bac66c24..c95da45d 100644
          }
  
          //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
-@@ -10856,7 +10849,7 @@ struct llm_build_context {
+@@ -8508,7 +8501,7 @@ struct llm_build_context {
      }
  };
  
@@ -127,7 +127,7 @@ index bac66c24..c95da45d 100644
      llama_ubatch dummy = {};
      dummy.equal_seqs = true;
  
-@@ -10866,7 +10859,7 @@ static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const
+@@ -8518,7 +8511,7 @@ static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const
  
      llm.init();
  
@@ -136,21 +136,21 @@ index bac66c24..c95da45d 100644
  
      llm.free();
  
-@@ -11329,7 +11322,12 @@ static int llama_decode_internal(
-                 kv_self.head = 0;
-             }
+@@ -8956,7 +8949,12 @@ static int llama_prepare_ubatch(
+             kv_self.head = 0;
+         }
  
--            const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
-+            auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
-+            if (!slot) {
-+                llama_kv_cache_defrag(kv_self);
-+                llama_kv_cache_update(&lctx);
-+                slot = llama_kv_cache_find_slot(kv_self, ubatch);
-+            }
-             if (!slot) {
-                 return 1;
-             }
-@@ -11735,8 +11733,8 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
+-        const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
++        auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
++        if (!slot) {
++            llama_kv_cache_defrag(kv_self);
++            llama_kv_cache_update(&lctx);
++            slot = llama_kv_cache_find_slot(kv_self, ubatch);
++        }
+         if (!slot) {
+             return 1;
+         }
+@@ -9431,8 +9429,8 @@ static void llama_kv_cache_defrag_impl(struct llama_context & lctx) {
  
      //const int64_t t_start = ggml_time_us();
  
@@ -161,7 +161,7 @@ index bac66c24..c95da45d 100644
  
      // each move requires 6*n_layer tensors (see build_defrag)
      //   - source view, destination view, copy operation
-@@ -11800,19 +11798,11 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
+@@ -9496,19 +9494,11 @@ static void llama_kv_cache_defrag_impl(struct llama_context & lctx) {
          // are we moving a continuous block of memory?
          bool cont = false;
  
@@ -181,7 +181,7 @@ index bac66c24..c95da45d 100644
                  cont = false;
                  continue;
              }
-@@ -11828,8 +11818,10 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
+@@ -9524,8 +9514,10 @@ static void llama_kv_cache_defrag_impl(struct llama_context & lctx) {
              kv_self.head = n_used;
  
              if (!cont) {
@@ -193,7 +193,7 @@ index bac66c24..c95da45d 100644
              }
  
              nf++;
-@@ -11839,22 +11831,16 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
+@@ -9535,22 +9527,16 @@ static void llama_kv_cache_defrag_impl(struct llama_context & lctx) {
              }
          }
  
@@ -218,7 +218,7 @@ index bac66c24..c95da45d 100644
  
  #if 0
      // CPU defrag
-@@ -11929,11 +11915,18 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
+@@ -9625,11 +9611,18 @@ static void llama_kv_cache_defrag_impl(struct llama_context & lctx) {
  #else
      // ggml_graph defrag
  
diff --git a/llama/patches/0013-use-dynamic-backend-loading-for-clip.patch b/llama/patches/0012-use-dynamic-backend-loading-for-clip.patch
similarity index 94%
rename from llama/patches/0013-use-dynamic-backend-loading-for-clip.patch
rename to llama/patches/0012-use-dynamic-backend-loading-for-clip.patch
index e283a857..e2857169 100644
--- a/llama/patches/0013-use-dynamic-backend-loading-for-clip.patch
+++ b/llama/patches/0012-use-dynamic-backend-loading-for-clip.patch
@@ -8,12 +8,12 @@ Subject: [PATCH] use dynamic backend loading for clip
  1 file changed, 27 insertions(+), 47 deletions(-)
 
 diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
-index b3c1829f..86b91d5c 100644
+index 205af1eb..560021c7 100644
 --- a/examples/llava/clip.cpp
 +++ b/examples/llava/clip.cpp
-@@ -8,25 +8,25 @@
- #include "ggml-alloc.h"
+@@ -9,25 +9,25 @@
  #include "ggml-backend.h"
+ #include "gguf.h"
  
 -//#ifdef GGML_USE_CUDA
 -//#include "ggml-cuda.h"
@@ -56,7 +56,7 @@ index b3c1829f..86b91d5c 100644
  
  #define STB_IMAGE_IMPLEMENTATION
  #include "stb_image.h"
-@@ -1235,35 +1235,15 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
+@@ -1309,35 +1309,15 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
          }
      }
  
diff --git a/llama/patches/0014-sort-devices-by-score.patch b/llama/patches/0013-sort-devices-by-score.patch
similarity index 98%
rename from llama/patches/0014-sort-devices-by-score.patch
rename to llama/patches/0013-sort-devices-by-score.patch
index d57a4366..7640a8db 100644
--- a/llama/patches/0014-sort-devices-by-score.patch
+++ b/llama/patches/0013-sort-devices-by-score.patch
@@ -8,7 +8,7 @@ Subject: [PATCH] sort devices by score
  1 file changed, 13 insertions(+), 8 deletions(-)
 
 diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
-index 899d16f2..135f7df0 100644
+index 95036ef8..98d5e14d 100644
 --- a/ggml/src/ggml-backend-reg.cpp
 +++ b/ggml/src/ggml-backend-reg.cpp
 @@ -150,7 +150,7 @@ struct ggml_backend_reg_entry {
diff --git a/llama/patches/0015-add-phony-target-ggml-cpu-for-all-cpu-variants.patch b/llama/patches/0014-add-phony-target-ggml-cpu-for-all-cpu-variants.patch
similarity index 86%
rename from llama/patches/0015-add-phony-target-ggml-cpu-for-all-cpu-variants.patch
rename to llama/patches/0014-add-phony-target-ggml-cpu-for-all-cpu-variants.patch
index e68950a5..f263ece1 100644
--- a/llama/patches/0015-add-phony-target-ggml-cpu-for-all-cpu-variants.patch
+++ b/llama/patches/0014-add-phony-target-ggml-cpu-for-all-cpu-variants.patch
@@ -8,10 +8,10 @@ Subject: [PATCH] add phony target ggml-cpu for all cpu variants
  1 file changed, 2 insertions(+)
 
 diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
-index 84101c32..72b488dd 100644
+index 0002ac18..0a8d1092 100644
 --- a/ggml/src/CMakeLists.txt
 +++ b/ggml/src/CMakeLists.txt
-@@ -278,6 +278,7 @@ function(ggml_add_cpu_backend_variant tag_name)
+@@ -297,6 +297,7 @@ function(ggml_add_cpu_backend_variant tag_name)
      endforeach()
  
      ggml_add_cpu_backend_variant_impl(${tag_name})
@@ -19,7 +19,7 @@ index 84101c32..72b488dd 100644
  endfunction()
  
  ggml_add_backend(CPU)
-@@ -286,6 +287,7 @@ if (GGML_CPU_ALL_VARIANTS)
+@@ -305,6 +306,7 @@ if (GGML_CPU_ALL_VARIANTS)
      if (NOT GGML_BACKEND_DL)
          message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS requires GGML_BACKEND_DL")
      endif()
diff --git a/llama/patches/0017-try-catch-backend-load.patch b/llama/patches/0015-try-catch-backend-load.patch
similarity index 99%
rename from llama/patches/0017-try-catch-backend-load.patch
rename to llama/patches/0015-try-catch-backend-load.patch
index e7f71c7c..9aea6183 100644
--- a/llama/patches/0017-try-catch-backend-load.patch
+++ b/llama/patches/0015-try-catch-backend-load.patch
@@ -8,7 +8,7 @@ Subject: [PATCH] try/catch backend load
  1 file changed, 23 insertions(+), 22 deletions(-)
 
 diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
-index 135f7df0..84b21dd8 100644
+index 98d5e14d..1c19129a 100644
 --- a/ggml/src/ggml-backend-reg.cpp
 +++ b/ggml/src/ggml-backend-reg.cpp
 @@ -512,32 +512,33 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
diff --git a/llama/patches/0016-remove-sgemm-global-variables.patch b/llama/patches/0016-remove-sgemm-global-variables.patch
deleted file mode 100644
index 31a59aea..00000000
--- a/llama/patches/0016-remove-sgemm-global-variables.patch
+++ /dev/null
@@ -1,55 +0,0 @@
-From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
-From: jmorganca 
-Date: Sun, 9 Feb 2025 17:22:15 -0800
-Subject: [PATCH] remove sgemm global variables
-
-removes the 'iq4nlt' global variable in sgemm.cpp that causes
-a runtime crash when calling dlopen on ggml-cpu libraries as
-its initialization depends on AVX instructions the host machine
-may not have
----
- ggml/src/ggml-cpu/llamafile/sgemm.cpp | 17 +++++++++--------
- 1 file changed, 9 insertions(+), 8 deletions(-)
-
-diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp
-index 8fce576c..3f260ce5 100644
---- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp
-+++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp
-@@ -279,14 +279,6 @@ template <> inline __m256bh load(const float *p) {
- }
- #endif
- 
--////////////////////////////////////////////////////////////////////////////////////////////////////
--// CONSTANTS
--
--#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
--static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
--static const __m128i iq4nlt = _mm_loadu_si128((const __m128i *) kvalues_iq4nl);
--#endif
--
- ////////////////////////////////////////////////////////////////////////////////////////////////////
- // FLOATING POINT MATRIX MULTIPLICATION
- 
-@@ -613,6 +605,14 @@ class tinyBLAS_Q0_AVX {
-                     TC *C, int64_t ldc,
-                     int ith, int nth)
-         : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
-+        const int8_t kvalues_iq4nl[16] = {
-+            -127, -104, -83, -65,
-+            -49,  -35,  -22, -10,
-+              1,   13,   25,  38,
-+             53,   69,   89, 113
-+        };
-+
-+        iq4nlt = _mm_loadu_si128((const __m128i *)kvalues_iq4nl);
-     }
- 
-     void matmul(int64_t m, int64_t n) {
-@@ -1037,6 +1037,7 @@ class tinyBLAS_Q0_AVX {
-     const int64_t ldc;
-     const int ith;
-     const int nth;
-+    __m128i iq4nlt;
- };
- #endif // __AVX__
- 
diff --git a/llama/patches/0018-use-std-filesystem-path-instead-of-wstring.patch b/llama/patches/0016-use-std-filesystem-path-instead-of-wstring.patch
similarity index 99%
rename from llama/patches/0018-use-std-filesystem-path-instead-of-wstring.patch
rename to llama/patches/0016-use-std-filesystem-path-instead-of-wstring.patch
index 95144fb4..d60066c1 100644
--- a/llama/patches/0018-use-std-filesystem-path-instead-of-wstring.patch
+++ b/llama/patches/0016-use-std-filesystem-path-instead-of-wstring.patch
@@ -8,7 +8,7 @@ Subject: [PATCH] use std::filesystem::path instead of wstring
  1 file changed, 58 insertions(+), 86 deletions(-)
 
 diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
-index 84b21dd8..e35a6936 100644
+index 1c19129a..c854e6bb 100644
 --- a/ggml/src/ggml-backend-reg.cpp
 +++ b/ggml/src/ggml-backend-reg.cpp
 @@ -66,26 +66,6 @@
diff --git a/llama/patches/0019-remove-amx.patch b/llama/patches/0017-remove-amx.patch
similarity index 89%
rename from llama/patches/0019-remove-amx.patch
rename to llama/patches/0017-remove-amx.patch
index 5428ee64..234d51cc 100644
--- a/llama/patches/0019-remove-amx.patch
+++ b/llama/patches/0017-remove-amx.patch
@@ -8,10 +8,10 @@ Subject: [PATCH] remove amx
  1 file changed, 4 deletions(-)
 
 diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
-index 72b488dd..50828717 100644
+index 0a8d1092..4564df91 100644
 --- a/ggml/src/CMakeLists.txt
 +++ b/ggml/src/CMakeLists.txt
-@@ -293,10 +293,6 @@ if (GGML_CPU_ALL_VARIANTS)
+@@ -312,10 +312,6 @@ if (GGML_CPU_ALL_VARIANTS)
      ggml_add_cpu_backend_variant(skylakex       AVX F16C AVX2 FMA AVX512)
      ggml_add_cpu_backend_variant(icelake        AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
      ggml_add_cpu_backend_variant(alderlake      AVX F16C AVX2 FMA AVX_VNNI)
@@ -19,6 +19,6 @@ index 72b488dd..50828717 100644
 -        # MSVC doesn't support AMX
 -        ggml_add_cpu_backend_variant(sapphirerapids AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
 -    endif()
- else ()
+ elseif (GGML_CPU)
      ggml_add_cpu_backend_variant_impl("")
  endif()
diff --git a/llama/patches/0018-fix-clip-compiler-error.patch b/llama/patches/0018-fix-clip-compiler-error.patch
new file mode 100644
index 00000000..ef6e247b
--- /dev/null
+++ b/llama/patches/0018-fix-clip-compiler-error.patch
@@ -0,0 +1,36 @@
+From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
+From: jmorganca 
+Date: Tue, 25 Feb 2025 19:14:51 -0800
+Subject: [PATCH] fix-clip-compiler-error
+
+---
+ examples/llava/clip.cpp | 2 +-
+ examples/llava/clip.h   | 2 +-
+ 2 files changed, 2 insertions(+), 2 deletions(-)
+
+diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
+index 560021c7..54265beb 100644
+--- a/examples/llava/clip.cpp
++++ b/examples/llava/clip.cpp
+@@ -1788,7 +1788,7 @@ void clip_image_f32_batch_free(struct clip_image_f32_batch  * batch) {
+     }
+ }
+ 
+-void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, clip_image_u8 * img) {
++void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, struct clip_image_u8 * img) {
+     img->nx = nx;
+     img->ny = ny;
+     img->buf.resize(3 * nx * ny);
+diff --git a/examples/llava/clip.h b/examples/llava/clip.h
+index ce6f6194..f9f80d7d 100644
+--- a/examples/llava/clip.h
++++ b/examples/llava/clip.h
+@@ -75,7 +75,7 @@ CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch  * batch);
+ CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch);
+ 
+ /** build image from pixels decoded by other libraries instead of stb_image.h for better performance. The memory layout is RGBRGBRGB..., input buffer length must be 3*nx*ny bytes */
+-CLIP_API void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, clip_image_u8 * img);
++CLIP_API void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, struct clip_image_u8 * img);
+ 
+ CLIP_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);
+ 
diff --git a/ml/backend/ggml/ggml/include/ggml-backend.h b/ml/backend/ggml/ggml/include/ggml-backend.h
index 7221a083..fc9571c8 100644
--- a/ml/backend/ggml/ggml/include/ggml-backend.h
+++ b/ml/backend/ggml/ggml/include/ggml-backend.h
@@ -203,6 +203,8 @@ extern "C" {
     // Backend registry
     //
 
+    GGML_API void ggml_backend_device_register(ggml_backend_dev_t device);
+
     // Backend (reg) enumeration
     GGML_API size_t             ggml_backend_reg_count(void);
     GGML_API ggml_backend_reg_t ggml_backend_reg_get(size_t index);
diff --git a/ml/backend/ggml/ggml/include/ggml-cpp.h b/ml/backend/ggml/ggml/include/ggml-cpp.h
index 219361af..a12342c2 100644
--- a/ml/backend/ggml/ggml/include/ggml-cpp.h
+++ b/ml/backend/ggml/ggml/include/ggml-cpp.h
@@ -7,6 +7,7 @@
 #include "ggml.h"
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
+#include "gguf.h"
 #include 
 
 // Smart pointers for ggml types
diff --git a/ml/backend/ggml/ggml/include/ggml-cpu.h b/ml/backend/ggml/ggml/include/ggml-cpu.h
index 3aa71bad..b48cc560 100644
--- a/ml/backend/ggml/ggml/include/ggml-cpu.h
+++ b/ml/backend/ggml/ggml/include/ggml-cpu.h
@@ -8,7 +8,7 @@ extern "C" {
 #endif
 
     // the compute plan that needs to be prepared for ggml_graph_compute()
-    // since https://github.com/ggerganov/ggml/issues/287
+    // since https://github.com/ggml-org/ggml/issues/287
     struct ggml_cplan {
         size_t    work_size; // size of work buffer, calculated by `ggml_graph_plan()`
         uint8_t * work_data; // work buffer, to be allocated by caller before calling to `ggml_graph_compute()`
@@ -95,9 +95,11 @@ extern "C" {
     GGML_BACKEND_API int ggml_cpu_has_matmul_int8(void);
     GGML_BACKEND_API int ggml_cpu_has_sve        (void);
     GGML_BACKEND_API int ggml_cpu_get_sve_cnt    (void);  // sve vector length in bytes
+    GGML_BACKEND_API int ggml_cpu_has_sme        (void);
     // other
     GGML_BACKEND_API int ggml_cpu_has_riscv_v    (void);
     GGML_BACKEND_API int ggml_cpu_has_vsx        (void);
+    GGML_BACKEND_API int ggml_cpu_has_vxe        (void);
     GGML_BACKEND_API int ggml_cpu_has_wasm_simd  (void);
     GGML_BACKEND_API int ggml_cpu_has_llamafile  (void);
 
diff --git a/ml/backend/ggml/ggml/include/ggml-metal.h b/ml/backend/ggml/ggml/include/ggml-metal.h
index 669c1f84..a6106944 100644
--- a/ml/backend/ggml/ggml/include/ggml-metal.h
+++ b/ml/backend/ggml/ggml/include/ggml-metal.h
@@ -45,7 +45,7 @@ GGML_BACKEND_API bool ggml_backend_is_metal(ggml_backend_t backend);
 
 GGML_DEPRECATED(
         GGML_BACKEND_API ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size),
-        "obsoleted by the new device interface - https://github.com/ggerganov/llama.cpp/pull/9713");
+        "obsoleted by the new device interface - https://github.com/ggml-org/llama.cpp/pull/9713");
 
 GGML_BACKEND_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data);
 
diff --git a/ml/backend/ggml/ggml/include/ggml-vulkan.h b/ml/backend/ggml/ggml/include/ggml-vulkan.h
index 53cdba07..ed5ea5f7 100644
--- a/ml/backend/ggml/ggml/include/ggml-vulkan.h
+++ b/ml/backend/ggml/ggml/include/ggml-vulkan.h
@@ -10,8 +10,6 @@ extern "C" {
 #define GGML_VK_NAME "Vulkan"
 #define GGML_VK_MAX_DEVICES 16
 
-GGML_BACKEND_API void ggml_vk_instance_init(void);
-
 // backend API
 GGML_BACKEND_API ggml_backend_t ggml_backend_vk_init(size_t dev_num);
 
diff --git a/ml/backend/ggml/ggml/include/ggml.h b/ml/backend/ggml/ggml/include/ggml.h
index 1bc50fca..8d269a9c 100644
--- a/ml/backend/ggml/ggml/include/ggml.h
+++ b/ml/backend/ggml/ggml/include/ggml.h
@@ -198,7 +198,7 @@
 
 #ifndef __GNUC__
 #    define GGML_ATTRIBUTE_FORMAT(...)
-#elif defined(__MINGW32__)
+#elif defined(__MINGW32__) && !defined(__clang__)
 #    define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
 #else
 #    define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
@@ -241,12 +241,6 @@
 #define GGML_ROPE_TYPE_MROPE  8
 #define GGML_ROPE_TYPE_VISION 24
 
-#define GGUF_MAGIC "GGUF"
-
-#define GGUF_VERSION 3
-
-#define GGUF_DEFAULT_ALIGNMENT 32
-
 #define GGML_UNUSED(x) (void)(x)
 
 #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
@@ -403,12 +397,6 @@ extern "C" {
         GGML_PREC_F32,
     };
 
-    enum ggml_backend_type {
-        GGML_BACKEND_TYPE_CPU = 0,
-        GGML_BACKEND_TYPE_GPU = 10,
-        GGML_BACKEND_TYPE_GPU_SPLIT = 20,
-    };
-
     // model file types
     enum ggml_ftype {
         GGML_FTYPE_UNKNOWN        = -1,
@@ -514,6 +502,7 @@ extern "C" {
         GGML_OP_GET_REL_POS,
         GGML_OP_ADD_REL_POS,
         GGML_OP_RWKV_WKV6,
+        GGML_OP_GATED_LINEAR_ATTN,
 
         GGML_OP_UNARY,
 
@@ -588,8 +577,6 @@ extern "C" {
     struct ggml_tensor {
         enum ggml_type type;
 
-        GGML_DEPRECATED(enum ggml_backend_type backend, "use the buffer type to find the storage location of the tensor");
-
         struct ggml_backend_buffer * buffer;
 
         int64_t ne[GGML_MAX_DIMS]; // number of elements
@@ -1398,16 +1385,20 @@ extern "C" {
             float                 scale,
             float                 max_bias);
 
-    GGML_API struct ggml_tensor * ggml_soft_max_back(
+    GGML_API struct ggml_tensor * ggml_soft_max_ext_back(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
+            struct ggml_tensor  * b,
+            float                 scale,
+            float                 max_bias);
 
     // in-place, returns view(a)
-    GGML_API struct ggml_tensor * ggml_soft_max_back_inplace(
+    GGML_API struct ggml_tensor * ggml_soft_max_ext_back_inplace(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
+            struct ggml_tensor  * b,
+            float                 scale,
+            float                 max_bias);
 
     // rotary position embedding
     // if (mode & 1) - skip n_past elements (NOT SUPPORTED)
@@ -1514,7 +1505,7 @@ extern "C" {
 
     // rotary position embedding backward, i.e compute dx from dy
     // a - dy
-    GGML_API struct ggml_tensor * ggml_rope_back(
+    GGML_API struct ggml_tensor * ggml_rope_ext_back(
             struct ggml_context * ctx,
             struct ggml_tensor  * a, // gradients of ggml_rope result
             struct ggml_tensor  * b, // positions
@@ -1529,6 +1520,23 @@ extern "C" {
             float                 beta_fast,
             float                 beta_slow);
 
+    GGML_API struct ggml_tensor * ggml_rope_multi_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * c,
+            int                   n_dims,
+            int                   sections[4],
+            int                   mode,
+            int                   n_ctx_orig,
+            float                 freq_base,
+            float                 freq_scale,
+            float                 ext_factor,
+            float                 attn_factor,
+            float                 beta_fast,
+            float                 beta_slow);
+
+
     // clamp
     // in-place, returns view(a)
     GGML_API struct ggml_tensor * ggml_clamp(
@@ -1777,7 +1785,7 @@ extern "C" {
             struct ggml_tensor  * a,
             int                   k);
 
-#define GGML_KQ_MASK_PAD 32
+#define GGML_KQ_MASK_PAD 64
 
     // q:    [n_embd, n_batch,     n_head,    1]
     // k:    [n_embd, n_kv,        n_head_kv, 1]
@@ -1883,6 +1891,15 @@ extern "C" {
             struct ggml_tensor  * td,
             struct ggml_tensor  * state);
 
+    GGML_API struct ggml_tensor * ggml_gated_linear_attn(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * q,
+            struct ggml_tensor  * g,
+            struct ggml_tensor  * state,
+            float scale);
+
     // custom operators
 
     typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
@@ -2121,132 +2138,6 @@ extern "C" {
                    int64_t   n_per_row,
                const float * imatrix);
 
-    //
-    // gguf
-    //
-
-    enum gguf_type {
-        GGUF_TYPE_UINT8   = 0,
-        GGUF_TYPE_INT8    = 1,
-        GGUF_TYPE_UINT16  = 2,
-        GGUF_TYPE_INT16   = 3,
-        GGUF_TYPE_UINT32  = 4,
-        GGUF_TYPE_INT32   = 5,
-        GGUF_TYPE_FLOAT32 = 6,
-        GGUF_TYPE_BOOL    = 7,
-        GGUF_TYPE_STRING  = 8,
-        GGUF_TYPE_ARRAY   = 9,
-        GGUF_TYPE_UINT64  = 10,
-        GGUF_TYPE_INT64   = 11,
-        GGUF_TYPE_FLOAT64 = 12,
-        GGUF_TYPE_COUNT,       // marks the end of the enum
-    };
-
-    struct gguf_context;
-
-    struct gguf_init_params {
-        bool no_alloc;
-
-        // if not NULL, create a ggml_context and allocate the tensor data in it
-        struct ggml_context ** ctx;
-    };
-
-    GGML_API struct gguf_context * gguf_init_empty(void);
-    GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params);
-    //GGML_API struct gguf_context * gguf_init_from_buffer(..);
-
-    GGML_API void gguf_free(struct gguf_context * ctx);
-
-    GGML_API const char * gguf_type_name(enum gguf_type type);
-
-    GGML_API int    gguf_get_version    (const struct gguf_context * ctx);
-    GGML_API size_t gguf_get_alignment  (const struct gguf_context * ctx);
-    GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx);
-    GGML_API void * gguf_get_data       (const struct gguf_context * ctx);
-
-    GGML_API int          gguf_get_n_kv(const struct gguf_context * ctx);
-    GGML_API int          gguf_find_key(const struct gguf_context * ctx, const char * key);
-    GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int key_id);
-
-    GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int key_id);
-    GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id);
-
-    // will abort if the wrong type is used for the key
-    GGML_API uint8_t      gguf_get_val_u8  (const struct gguf_context * ctx, int key_id);
-    GGML_API int8_t       gguf_get_val_i8  (const struct gguf_context * ctx, int key_id);
-    GGML_API uint16_t     gguf_get_val_u16 (const struct gguf_context * ctx, int key_id);
-    GGML_API int16_t      gguf_get_val_i16 (const struct gguf_context * ctx, int key_id);
-    GGML_API uint32_t     gguf_get_val_u32 (const struct gguf_context * ctx, int key_id);
-    GGML_API int32_t      gguf_get_val_i32 (const struct gguf_context * ctx, int key_id);
-    GGML_API float        gguf_get_val_f32 (const struct gguf_context * ctx, int key_id);
-    GGML_API uint64_t     gguf_get_val_u64 (const struct gguf_context * ctx, int key_id);
-    GGML_API int64_t      gguf_get_val_i64 (const struct gguf_context * ctx, int key_id);
-    GGML_API double       gguf_get_val_f64 (const struct gguf_context * ctx, int key_id);
-    GGML_API bool         gguf_get_val_bool(const struct gguf_context * ctx, int key_id);
-    GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id);
-    GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id);
-    GGML_API int          gguf_get_arr_n   (const struct gguf_context * ctx, int key_id);
-    GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id);
-    GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);
-
-    GGML_API int            gguf_get_n_tensors    (const struct gguf_context * ctx);
-    GGML_API int            gguf_find_tensor      (const struct gguf_context * ctx, const char * name);
-    GGML_API size_t         gguf_get_tensor_offset(const struct gguf_context * ctx, int i);
-    GGML_API char *         gguf_get_tensor_name  (const struct gguf_context * ctx, int i);
-    GGML_API enum ggml_type gguf_get_tensor_type  (const struct gguf_context * ctx, int i);
-
-    // removes key if it exists
-    GGML_API void gguf_remove_key(struct gguf_context * ctx, const char * key);
-
-    // overrides existing values or adds a new one
-    GGML_API void gguf_set_val_u8  (struct gguf_context * ctx, const char * key, uint8_t  val);
-    GGML_API void gguf_set_val_i8  (struct gguf_context * ctx, const char * key, int8_t   val);
-    GGML_API void gguf_set_val_u16 (struct gguf_context * ctx, const char * key, uint16_t val);
-    GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t  val);
-    GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t val);
-    GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t  val);
-    GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float    val);
-    GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t val);
-    GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t  val);
-    GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * key, double   val);
-    GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool     val);
-    GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val);
-    GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n);
-    GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, int n);
-
-    // set or add KV pairs from another context
-    GGML_API void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src);
-
-    // manage tensor info
-    GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct ggml_tensor * tensor);
-    GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type);
-    GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size);
-
-    // writing gguf files can be done in 2 ways:
-    //
-    // - write the entire gguf_context to a binary file in a single pass:
-    //
-    //   gguf_write_to_file(ctx, fname);
-    //
-    // - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data:
-    //
-    //   FILE * f = fopen(fname, "wb");
-    //   fseek(f, gguf_get_meta_size(ctx), SEEK_SET);
-    //   fwrite(f, ...);
-    //   void * data = gguf_meta_get_meta_data(ctx);
-    //   fseek(f, 0, SEEK_SET);
-    //   fwrite(f, data, gguf_get_meta_size(ctx));
-    //   free(data);
-    //   fclose(f);
-    //
-
-    // write the entire context to a binary file
-    GGML_API void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta);
-
-    // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding
-    GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx);
-    GGML_API void   gguf_get_meta_data(const struct gguf_context * ctx, void * data);
-
 #ifdef __cplusplus
     // restrict not standard in C++
 #    if defined(__GNUC__)
diff --git a/ml/backend/ggml/ggml/include/gguf.h b/ml/backend/ggml/ggml/include/gguf.h
new file mode 100644
index 00000000..79ee2020
--- /dev/null
+++ b/ml/backend/ggml/ggml/include/gguf.h
@@ -0,0 +1,202 @@
+// This file contains functionality related to "GGUF" files, the binary file format used by ggml.
+// GGUF files have the following structure:
+//
+// 1. File magic "GGUF" (4 bytes).
+// 2. File version (uint32_t).
+// 3. Number of ggml tensors in file (int64_t).
+// 4. Number of key-value-pairs in file (int64_t).
+// 5. For each KV pair:
+//   1. The key (string).
+//   2. The value type (gguf_type).
+//   3a. If the value type is GGUF_TYPE_ARRAY:
+//     1. The type of the array (gguf_type).
+//     2. The number of elements in the array (uint64_t).
+//     3. The binary representation of each element in the array.
+//   3b. Otherwise:
+//     1. The binary representation of the value.
+// 6. For each ggml tensor:
+//   1. The tensor name (string).
+//   2. The number of dimensions of the tensor (uint32_t).
+//   3. For each dimension:
+//     1. The size of the tensor in the dimension (int64_t).
+//   4. The tensor data type (ggml_type).
+//   5. The tensor data offset in the tensor data binary blob (uint64_t).
+// 7. The tensor data binary blob (optional, aligned).
+//
+// Strings are serialized as the string length (uint64_t) followed by the C string without the null terminator.
+// All enums are stored as int32_t.
+// All bool values are stored as int8_t.
+// If the special key "general.alignment" (uint32_t) is defined it is used for alignment,
+//   otherwise GGUF_DEFAULT_ALIGNMENT is used.
+//
+// Module maintainer: Johannes Gäßler (@JohannesGaessler, johannesg@5d6.de)
+
+#pragma once
+
+#include "ggml.h"
+
+#include 
+#include 
+
+#define GGUF_MAGIC   "GGUF"
+#define GGUF_VERSION 3
+
+#define GGUF_KEY_GENERAL_ALIGNMENT "general.alignment"
+
+#define GGUF_DEFAULT_ALIGNMENT 32
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+    // types that can be stored as GGUF KV data
+    enum gguf_type {
+        GGUF_TYPE_UINT8   = 0,
+        GGUF_TYPE_INT8    = 1,
+        GGUF_TYPE_UINT16  = 2,
+        GGUF_TYPE_INT16   = 3,
+        GGUF_TYPE_UINT32  = 4,
+        GGUF_TYPE_INT32   = 5,
+        GGUF_TYPE_FLOAT32 = 6,
+        GGUF_TYPE_BOOL    = 7,
+        GGUF_TYPE_STRING  = 8,
+        GGUF_TYPE_ARRAY   = 9,
+        GGUF_TYPE_UINT64  = 10,
+        GGUF_TYPE_INT64   = 11,
+        GGUF_TYPE_FLOAT64 = 12,
+        GGUF_TYPE_COUNT,       // marks the end of the enum
+    };
+
+    struct gguf_context;
+
+    struct gguf_init_params {
+        bool no_alloc;
+
+        // if not NULL, create a ggml_context and allocate the tensor data in it
+        struct ggml_context ** ctx;
+    };
+
+    GGML_API struct gguf_context * gguf_init_empty(void);
+    GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params);
+    //GGML_API struct gguf_context * gguf_init_from_buffer(..);
+
+    GGML_API void gguf_free(struct gguf_context * ctx);
+
+    GGML_API const char * gguf_type_name(enum gguf_type type);
+
+    GGML_API uint32_t gguf_get_version    (const struct gguf_context * ctx);
+    GGML_API size_t   gguf_get_alignment  (const struct gguf_context * ctx);
+    GGML_API size_t   gguf_get_data_offset(const struct gguf_context * ctx);
+
+    GGML_API int64_t      gguf_get_n_kv(const struct gguf_context * ctx);
+    GGML_API int64_t      gguf_find_key(const struct gguf_context * ctx, const char * key); // returns -1 if key is not found
+    GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int64_t key_id);
+
+    GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id);
+
+    // will abort if the wrong type is used for the key
+    GGML_API uint8_t      gguf_get_val_u8  (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API int8_t       gguf_get_val_i8  (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API uint16_t     gguf_get_val_u16 (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API int16_t      gguf_get_val_i16 (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API uint32_t     gguf_get_val_u32 (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API int32_t      gguf_get_val_i32 (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API float        gguf_get_val_f32 (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API uint64_t     gguf_get_val_u64 (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API int64_t      gguf_get_val_i64 (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API double       gguf_get_val_f64 (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API bool         gguf_get_val_bool(const struct gguf_context * ctx, int64_t key_id);
+    GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id);
+    GGML_API size_t       gguf_get_arr_n   (const struct gguf_context * ctx, int64_t key_id);
+
+    // get raw pointer to the first element of the array with the given key_id
+    // for bool arrays, note that they are always stored as int8 on all platforms (usually this makes no difference)
+    GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id);
+
+    // get ith C string from array with given key_id
+    GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i);
+
+    GGML_API int64_t        gguf_get_n_tensors    (const struct gguf_context * ctx);
+    GGML_API int64_t        gguf_find_tensor      (const struct gguf_context * ctx, const char * name); // returns -1 if the tensor is not found
+    GGML_API size_t         gguf_get_tensor_offset(const struct gguf_context * ctx, int64_t tensor_id);
+    GGML_API const char *   gguf_get_tensor_name  (const struct gguf_context * ctx, int64_t tensor_id);
+    GGML_API enum ggml_type gguf_get_tensor_type  (const struct gguf_context * ctx, int64_t tensor_id);
+    GGML_API size_t         gguf_get_tensor_size  (const struct gguf_context * ctx, int64_t tensor_id);
+
+    // removes key if it exists, returns id that the key had prior to removal (-1 if it didn't exist)
+    GGML_API int64_t gguf_remove_key(struct gguf_context * ctx, const char * key);
+
+    // overrides an existing KV pair or adds a new one, the new KV pair is always at the back
+    GGML_API void gguf_set_val_u8  (struct gguf_context * ctx, const char * key, uint8_t      val);
+    GGML_API void gguf_set_val_i8  (struct gguf_context * ctx, const char * key, int8_t       val);
+    GGML_API void gguf_set_val_u16 (struct gguf_context * ctx, const char * key, uint16_t     val);
+    GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t      val);
+    GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t     val);
+    GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t      val);
+    GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float        val);
+    GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t     val);
+    GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t      val);
+    GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * key, double       val);
+    GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool         val);
+    GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val);
+
+    // creates a new array with n elements of the given type and copies the corresponding number of bytes from data
+    GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, size_t n);
+
+    // creates a new array with n strings and copies the corresponding strings from data
+    GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, size_t n);
+
+    // set or add KV pairs from another context
+    GGML_API void gguf_set_kv(struct gguf_context * ctx, const struct gguf_context * src);
+
+    // add tensor to GGUF context, tensor name must be unique
+    GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct ggml_tensor * tensor);
+
+    // after changing a tensor's type, the offsets of all tensors with higher indices are immediately recalculated
+    //   in such a way that the tensor data remains as one contiguous block (except for padding)
+    GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type);
+
+    // assumes that at least gguf_get_tensor_size bytes can be read from data
+    GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data);
+
+    // writing gguf files can be done in 3 ways:
+    //
+    // - write the entire gguf_context to a binary file in a single pass:
+    //
+    //   gguf_write_to_file(ctx, fname, /*only_meta =*/ false);
+    //
+    // - write only the meta data to a file, then re-open the file and append the tensor data:
+    //
+    //   gguf_write_to_file(ctx, fname, /*only_meta =*/ true);
+    //   FILE * f = fopen(fname, "ab");
+    //   fwrite(f, ...); // write tensor data
+    //   fclose(f);
+    //
+    // - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data:
+    //
+    //   FILE * f = fopen(fname, "wb");
+    //   const size_t size_meta = gguf_get_meta_size(ctx);
+    //   fseek(f, size_meta, SEEK_SET);
+    //   fwrite(f, ...); // write tensor data
+    //   void * data = malloc(size_meta);
+    //   gguf_get_meta_data(ctx, data);
+    //   rewind(f);
+    //   fwrite(data, 1, data, f);
+    //   free(data);
+    //   fclose(f);
+    //
+
+    // write the entire context to a binary file
+    GGML_API bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta);
+
+    // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding
+    GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx);
+
+    // writes the meta data to pointer "data"
+    GGML_API void   gguf_get_meta_data(const struct gguf_context * ctx, void * data);
+
+#ifdef  __cplusplus
+}
+#endif
diff --git a/ml/backend/ggml/ggml/src/CMakeLists.txt b/ml/backend/ggml/ggml/src/CMakeLists.txt
index 50828717..4564df91 100644
--- a/ml/backend/ggml/ggml/src/CMakeLists.txt
+++ b/ml/backend/ggml/ggml/src/CMakeLists.txt
@@ -93,12 +93,18 @@ endif()
 
 if (GGML_CCACHE)
     find_program(GGML_CCACHE_FOUND ccache)
+    find_program(GGML_SCCACHE_FOUND sccache)
 
-    if (GGML_CCACHE_FOUND)
+    if (GGML_CCACHE_FOUND OR GGML_SCCACHE_FOUND)
+        if(GGML_CCACHE_FOUND)
+            set(GGML_CCACHE_VARIANT ccache)
+        else()
+            set(GGML_CCACHE_VARIANT sccache)
+        endif()
         # TODO: should not be set globally
-        set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ccache)
+        set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${GGML_CCACHE_VARIANT}")
         set(ENV{CCACHE_SLOPPINESS} time_macros)
-        message(STATUS "ccache found, compilation results will be cached. Disable with GGML_CCACHE=OFF.")
+        message(STATUS "${GGML_CCACHE_VARIANT} found, compilation results will be cached. Disable with GGML_CCACHE=OFF.")
     else()
         message(STATUS "Warning: ccache not found - consider installing it for faster compilation or disable this warning with GGML_CCACHE=OFF")
     endif ()
@@ -208,6 +214,7 @@ add_library(ggml-base
             ../include/ggml-backend.h
             ../include/ggml-cpp.h
             ../include/ggml-opt.h
+            ../include/gguf.h
             ggml.c
             ggml-alloc.c
             ggml-backend.cpp
@@ -215,7 +222,8 @@ add_library(ggml-base
             ggml-threading.cpp
             ggml-threading.h
             ggml-quants.c
-            ggml-quants.h)
+            ggml-quants.h
+            gguf.cpp)
 
 target_include_directories(ggml-base PRIVATE .)
 
@@ -248,6 +256,17 @@ function(ggml_add_backend_library backend)
         target_compile_definitions(${backend} PRIVATE GGML_BACKEND_BUILD)
         target_compile_definitions(${backend} PUBLIC  GGML_BACKEND_SHARED)
     endif()
+
+    if(NOT GGML_AVAILABLE_BACKENDS)
+        set(GGML_AVAILABLE_BACKENDS "${backend}"
+            CACHE INTERNAL "List of backends for cmake package")
+    else()
+        list(FIND GGML_AVAILABLE_BACKENDS "${backend}" has_backend)
+        if(has_backend EQUAL -1)
+            set(GGML_AVAILABLE_BACKENDS "${GGML_AVAILABLE_BACKENDS};${backend}"
+                CACHE INTERNAL "List of backends for cmake package")
+        endif()
+    endif()
 endfunction()
 
 function(ggml_add_backend backend)
@@ -293,7 +312,7 @@ if (GGML_CPU_ALL_VARIANTS)
     ggml_add_cpu_backend_variant(skylakex       AVX F16C AVX2 FMA AVX512)
     ggml_add_cpu_backend_variant(icelake        AVX F16C AVX2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
     ggml_add_cpu_backend_variant(alderlake      AVX F16C AVX2 FMA AVX_VNNI)
-else ()
+elseif (GGML_CPU)
     ggml_add_cpu_backend_variant_impl("")
 endif()
 
diff --git a/ml/backend/ggml/ggml/src/ggml-alloc.c b/ml/backend/ggml/ggml/src/ggml-alloc.c
index 8dc8226a..7244a9cb 100644
--- a/ml/backend/ggml/ggml/src/ggml-alloc.c
+++ b/ml/backend/ggml/ggml/src/ggml-alloc.c
@@ -37,6 +37,7 @@ static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml
     return true;
 }
 
+// ops that return true for this function must not use restrict pointers for their backend implementations
 static bool ggml_op_can_inplace(enum ggml_op op) {
     switch (op) {
         case GGML_OP_SCALE:
@@ -52,8 +53,12 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
         case GGML_OP_LOG:
         case GGML_OP_UNARY:
         case GGML_OP_ROPE:
+        case GGML_OP_ROPE_BACK:
+        case GGML_OP_SILU_BACK:
         case GGML_OP_RMS_NORM:
+        case GGML_OP_RMS_NORM_BACK:
         case GGML_OP_SOFT_MAX:
+        case GGML_OP_SOFT_MAX_BACK:
             return true;
 
         default:
@@ -984,19 +989,7 @@ ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_conte
             this_size = GGML_PAD(ggml_backend_buft_get_alloc_size(buft, t), alignment);
         }
 
-        if (this_size > max_size) {
-            GGML_LOG_ERROR("%s: tensor %s is too large to fit in a %s buffer (tensor size: %zu, max buffer size: %zu)\n",
-                    __func__, t->name,
-                    ggml_backend_buft_name(buft),
-                    this_size, max_size);
-            for (size_t i = 0; i < n_buffers; i++) {
-                ggml_backend_buffer_free(buffers[i]);
-            }
-            free(buffers);
-            return NULL;
-        }
-
-        if ((cur_buf_size + this_size) > max_size) {
+        if (cur_buf_size > 0 && (cur_buf_size + this_size) > max_size) {
             // allocate tensors in the current buffer
             if (!alloc_tensor_range(ctx, first, t, buft, cur_buf_size, &buffers, &n_buffers)) {
                 return NULL;
diff --git a/ml/backend/ggml/ggml/src/ggml-backend-impl.h b/ml/backend/ggml/ggml/src/ggml-backend-impl.h
index 36d72e95..d1c2d76d 100644
--- a/ml/backend/ggml/ggml/src/ggml-backend-impl.h
+++ b/ml/backend/ggml/ggml/src/ggml-backend-impl.h
@@ -208,7 +208,6 @@ extern "C" {
 
     // Internal backend registry API
     GGML_API void ggml_backend_register(ggml_backend_reg_t reg);
-    GGML_API void ggml_backend_device_register(ggml_backend_dev_t device);
 
     // Add backend dynamic loading support to the backend
 
diff --git a/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp b/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp
index e35a6936..c854e6bb 100644
--- a/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp
@@ -552,4 +552,9 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
     ggml_backend_load_best("opencl", silent, dir_path);
     ggml_backend_load_best("musa", silent, dir_path);
     ggml_backend_load_best("cpu", silent, dir_path);
+    // check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend
+    const char * backend_path = std::getenv("GGML_BACKEND_PATH");
+    if (backend_path) {
+        ggml_backend_load(backend_path);
+    }
 }
diff --git a/ml/backend/ggml/ggml/src/ggml-backend.cpp b/ml/backend/ggml/ggml/src/ggml-backend.cpp
index a12172dc..1ca40b2c 100644
--- a/ml/backend/ggml/ggml/src/ggml-backend.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-backend.cpp
@@ -763,7 +763,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
         if (tensor->op != GGML_OP_ROPE && src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
             int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src, tensor);
             // check if a backend with higher prio wants to offload the op
-            if (src_backend_id == sched->n_backends - 1) {
+            if (src_backend_id == sched->n_backends - 1 && ggml_backend_buffer_is_host(src->buffer)) {
                 for (int b = 0; b < src_backend_id; b++) {
                     if (ggml_backend_supports_op(sched->backends[b], tensor) && ggml_backend_offload_op(sched->backends[b], tensor)) {
                         SET_CAUSE(tensor, "1.off");
diff --git a/ml/backend/ggml/ggml/src/ggml-common.h b/ml/backend/ggml/ggml/src/ggml-common.h
index f13fd4de..6c02b69e 100644
--- a/ml/backend/ggml/ggml/src/ggml-common.h
+++ b/ml/backend/ggml/ggml/src/ggml-common.h
@@ -473,7 +473,6 @@ GGML_TABLE_BEGIN(uint8_t, ksigns_iq2xs, 128)
     240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
 GGML_TABLE_END()
 
-//#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A // lowest compute capability for integer intrinsics
 GGML_TABLE_BEGIN(uint64_t, ksigns64, 128)
     0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff,
     0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff,
@@ -508,7 +507,6 @@ GGML_TABLE_BEGIN(uint64_t, ksigns64, 128)
     0x00ffffffff000000, 0xffffffffff0000ff, 0xffffffffff00ff00, 0x00ffffffff00ffff,
     0xffffffffffff0000, 0x00ffffffffff00ff, 0x00ffffffffffff00, 0xffffffffffffffff,
 GGML_TABLE_END()
-//#endif
 
 
 GGML_TABLE_BEGIN(uint64_t, iq2xxs_grid, 256)
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-cpu/CMakeLists.txt
index 6b3641c4..aa5ad5d8 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/CMakeLists.txt
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/CMakeLists.txt
@@ -111,14 +111,15 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
                 function(check_arm_feature tag code)
                     set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
                     set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+${tag}")
-                    check_cxx_source_runs(
-                        "${code}"
-                        GGML_MACHINE_SUPPORTS_${tag}
-                    )
+                    check_cxx_source_runs("${code}" GGML_MACHINE_SUPPORTS_${tag})
                     if (GGML_MACHINE_SUPPORTS_${tag})
                         set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+${tag}" PARENT_SCOPE)
                     else()
-                        set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+no${tag}" PARENT_SCOPE)
+                        set(CMAKE_REQUIRED_FLAGS "${ARM_MCPU_FLAG}+no${tag}")
+                        check_cxx_source_compiles("int main() { return 0; }" GGML_MACHINE_SUPPORTS_no${tag})
+                        if (GGML_MACHINE_SUPPORTS_no${tag})
+                            set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+no${tag}" PARENT_SCOPE)
+                        endif()
                     endif()
                     set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
                 endfunction()
@@ -126,6 +127,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
                 check_arm_feature(dotprod "#include \nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }")
                 check_arm_feature(i8mm    "#include \nint main() { int8x16_t _a, _b; volatile int32x4_t _s = vmmlaq_s32(_s, _a, _b); return 0; }")
                 check_arm_feature(sve     "#include \nint main()  { svfloat32_t _a, _b; volatile svfloat32_t _c = svadd_f32_z(svptrue_b8(), _a, _b); return 0; }")
+                check_arm_feature(sme     "#include \n__arm_locally_streaming int main() { __asm__ volatile(\"smstart; smstop;\"); return 0; }")
 
                 list(APPEND ARCH_FLAGS "${ARM_MCPU_FLAG}${ARM_MCPU_FLAG_FIX}")
             else()
@@ -150,7 +152,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
             if (ARM_FEATURE_RESULT)
                 message(WARNING "Failed to get ARM features")
             else()
-                foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC)
+                foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME)
                     string(FIND "${ARM_FEATURE}" "__ARM_FEATURE_${feature} 1" feature_pos)
                     if (NOT ${feature_pos} EQUAL -1)
                         message(STATUS "ARM feature ${feature} enabled")
@@ -308,6 +310,27 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
         if (GGML_RVV)
             list(APPEND ARCH_FLAGS -march=rv64gcv -mabi=lp64d)
         endif()
+    elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x")
+        message(STATUS "s390x detected")
+        file(READ "/proc/cpuinfo" CPUINFO_CONTENTS)
+        string(REGEX REPLACE "machine[ \t\r\n]*=[ \t\r\n]*([0-9]+)" "\\1" S390X_M ${CPUINFO_CONTENTS})
+
+        # TODO: Separation to determine activation of VX/VXE/VXE2
+        if (${S390X_M} MATCHES "8561|8562")
+            message(STATUS "z15 target")
+            list(APPEND ARCH_FLAGS -march=z15 -mtune=z15)
+        elseif (${S390X_M} MATCHES "3931")
+            message(STATUS "z16 target")
+            list(APPEND ARCH_FLAGS -march=z16 -mtune=z16)
+        else()
+            message(STATUS "Unknown target")
+            message(WARNING "Unknown target. If you are compiling for z14 and earlier, you might have to add -DGGML_VXE=OFF.")
+            list(APPEND ARCH_FLAGS -march=native -mtune=native)
+        endif()
+
+        if (GGML_VXE)
+            list(APPEND ARCH_FLAGS -mvx -mzvector)
+        endif()
     else()
         message(STATUS "Unknown architecture")
     endif()
@@ -316,6 +339,94 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
         target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_AARCH64)
     endif()
 
+    if (GGML_CPU_KLEIDIAI)
+        message(STATUS "Using KleidiAI optimized kernels if applicable")
+
+        # Disable the KleidiAI tests
+        set(KLEIDIAI_BUILD_TESTS  OFF)
+
+        # Fetch KleidiAI sources:
+        include(FetchContent)
+        set(KLEIDIAI_COMMIT_TAG "v1.3.0")
+        set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
+        set(KLEIDIAI_ARCHIVE_MD5  "060bd2dc64642b091f461cc8dd7426d9")
+
+        if (POLICY CMP0135)
+            cmake_policy(SET CMP0135 NEW)
+        endif()
+
+        FetchContent_Declare(KleidiAI_Download
+            URL ${KLEIDIAI_DOWNLOAD_URL}
+            DOWNLOAD_EXTRACT_TIMESTAMP NEW
+            URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5})
+
+        FetchContent_MakeAvailable(KleidiAI_Download)
+        FetchContent_GetProperties(KleidiAI_Download
+            SOURCE_DIR  KLEIDIAI_SRC
+            POPULATED   KLEIDIAI_POPULATED)
+
+        if (NOT KLEIDIAI_POPULATED)
+            message(FATAL_ERROR "KleidiAI source downloaded failed.")
+        endif()
+
+        add_compile_definitions(GGML_USE_CPU_KLEIDIAI)
+
+        # Remove kleidiai target after fetching it
+        if (TARGET kleidiai)
+            set_target_properties(kleidiai PROPERTIES EXCLUDE_FROM_ALL TRUE)
+        endif()
+
+        list(APPEND GGML_CPU_SOURCES
+            ggml-cpu/kleidiai/kleidiai.cpp
+            ggml-cpu/kleidiai/kernels.cpp
+            ggml-cpu/kleidiai/kleidiai.h
+            ggml-cpu/kleidiai/kernels.h
+            )
+
+        # KleidiAI
+        include_directories(
+            ${KLEIDIAI_SRC}/
+            ${KLEIDIAI_SRC}/kai/
+            ${KLEIDIAI_SRC}/kai/ukernels/
+            ${KLEIDIAI_SRC}/kai/ukernels/matmul/
+            ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/
+            ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/)
+
+        set(ARCH_FLAGS_TEMP "${ARCH_FLAGS}")
+        if (NOT ARCH_FLAGS_TEMP)
+            string(REGEX MATCH "-march=[^ ]+" ARCH_FLAGS_TEMP "${CMAKE_C_FLAGS}")
+        endif()
+        string(FIND "${ARCH_FLAGS_TEMP}" "+dotprod" DOTPROD_ENABLED)
+        string(FIND "${ARCH_FLAGS_TEMP}" "+i8mm" I8MM_ENABLED)
+        string(FIND "${ARCH_FLAGS_TEMP}" "+sme" SME_ENABLED)
+
+        set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS})
+
+        list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c)
+        list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c)
+        list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c)
+        list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)
+
+        if (NOT DOTPROD_ENABLED MATCHES -1)
+            list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c)
+            list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c)
+            list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c)
+        endif()
+
+        if (NOT I8MM_ENABLED MATCHES -1)
+            list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.c)
+        endif()
+
+        if (NOT SME_ENABLED MATCHES -1)
+            list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c)
+            list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c)
+            set(PRIVATE_ARCH_FLAGS "${PRIVATE_ARCH_FLAGS}+sve+sve2")
+        endif()
+
+        set_source_files_properties(${GGML_KLEIDIAI_SOURCES} PROPERTIES COMPILE_OPTIONS "${PRIVATE_ARCH_FLAGS}")
+        list(APPEND GGML_CPU_SOURCES ${GGML_KLEIDIAI_SOURCES})
+    endif()
+
     message(STATUS "Adding CPU backend variant ${GGML_CPU_NAME}: ${ARCH_FLAGS} ${ARCH_DEFINITIONS}")
     target_sources(${GGML_CPU_NAME} PRIVATE ${GGML_CPU_SOURCES})
     target_compile_options(${GGML_CPU_NAME} PRIVATE ${ARCH_FLAGS})
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp
index 622c63f1..b311a5b1 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp
@@ -4169,6 +4169,8 @@ static ggml_backend_buffer_t ggml_backend_cpu_aarch64_buffer_type_alloc_buffer(g
     buffer->buft              = buft;
     buffer->iface.init_tensor = ggml_backend_cpu_aarch64_buffer_init_tensor;
     buffer->iface.set_tensor  = ggml_backend_cpu_aarch64_buffer_set_tensor;
+    buffer->iface.get_tensor  = nullptr;
+    buffer->iface.cpy_tensor  = nullptr;
     return buffer;
 }
 
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-impl.h
index d71076ad..7f7d210c 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-impl.h
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-impl.h
@@ -59,6 +59,15 @@ struct ggml_compute_params {
 #endif
 #endif
 
+#if defined(__s390x__) && defined(__VEC__)
+#ifndef __VXE__
+#define __VXE__
+#endif
+#ifndef __VXE2__
+#define __VXE2__
+#endif
+#endif
+
 #if defined(__ARM_FEATURE_SVE)
 #include 
 #include 
@@ -359,22 +368,158 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b)
 #endif
 #endif
 
-#if defined(__loongarch_asx)
+#if defined(__VXE__) || defined(__VXE2__)
+#include 
 
-typedef union {
-    int32_t i;
-    float f;
-} ft_union;
+#define vec_neg(a)    (-(a))                // Vector Negate
+#define vec_add(a, b) ((a) + (b))           // Vector Add
+#define vec_sub(a, b) ((a) - (b))           // Vector Subtract
+#define vec_mul(a, b) ((a) * (b))           // Vector Multiply
+#define vec_div(a, b) ((a) / (b))           // Vector Divide
+#define vec_sl(a, b)  ((a) << (b))          // Vector Shift Left
+#define vec_sra(a, b) ((a) >> (b))          // Vector Shift Right
+#define vec_sr(a, b)  ((a) >> (b))          // Vector Shift Right Algebraic
+#define vec_slo(a, b) vec_slb(a, (b) << 64) // Vector Shift Left by Octet
+#define vec_sro(a, b) vec_srb(a, (b) << 64) // Vector Shift Right by Octet
 
-/* float type data load instructions */
-static __m128 __lsx_vreplfr2vr_s(float val) {
-    ft_union fi_tmpval = {.f = val};
-    return (__m128)__lsx_vreplgr2vr_w(fi_tmpval.i);
+#ifndef vec_and
+#define vec_and(a, b) ((a) & (b)) // Vector AND
+#endif
+
+#ifndef vec_or
+#define vec_or(a, b)  ((a) | (b)) // Vector OR
+#endif
+
+#ifndef vec_xor
+#define vec_xor(a, b) ((a) ^ (b)) // Vector XOR
+#endif
+
+typedef signed char char8x16_t __attribute__((vector_size(16)));
+typedef unsigned char uchar8x16_t __attribute__((vector_size(16)));
+
+typedef int8_t  int8x16_t __attribute__((vector_size(16)));
+typedef int16_t int16x8_t __attribute__((vector_size(16)));
+typedef int32_t int32x4_t __attribute__((vector_size(16)));
+
+typedef uint8_t  uint8x16_t __attribute__((vector_size(16)));
+typedef uint16_t uint16x8_t __attribute__((vector_size(16)));
+typedef uint32_t uint32x4_t __attribute__((vector_size(16)));
+
+typedef float float32x4_t __attribute__((vector_size(16)));
+typedef double double64x2_t __attribute((vector_size(16)));
+
+typedef signed long long long64x2_t __attribute((vector_size(16)));
+typedef unsigned long long ulong64x2_t __attribute__((vector_size(16)));
+
+typedef struct ggml_uint8x16x2_t {
+    uint8x16_t val[2];
+} ggml_uint8x16x2_t;
+
+inline static ggml_uint8x16x2_t ggml_vec_xl_u8x2(const uint8_t * ptr) {
+    ggml_uint8x16x2_t res;
+
+    res.val[0] = vec_xl( 0, ptr);
+    res.val[1] = vec_xl(16, ptr);
+
+    return res;
 }
 
-static __m256 __lasx_xvreplfr2vr_s(float val) {
-    ft_union fi_tmpval = {.f = val};
-    return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i);
+typedef struct ggml_uint8x16x4_t {
+    uint8x16_t val[4];
+} ggml_uint8x16x4_t;
+
+inline static ggml_uint8x16x4_t ggml_vec_xl_u8x4(const uint8_t * ptr) {
+    ggml_uint8x16x4_t res;
+
+    res.val[0] = vec_xl( 0, ptr);
+    res.val[1] = vec_xl(16, ptr);
+    res.val[2] = vec_xl(32, ptr);
+    res.val[3] = vec_xl(48, ptr);
+
+    return res;
+}
+
+typedef struct ggml_int8x16x4_t {
+    int8x16_t val[4];
+} ggml_int8x16x4_t;
+
+inline static ggml_int8x16x4_t ggml_vec_xl_s8x4(const int8_t * ptr) {
+    ggml_int8x16x4_t res;
+
+    res.val[0] = vec_xl( 0, ptr);
+    res.val[1] = vec_xl(16, ptr);
+    res.val[2] = vec_xl(32, ptr);
+    res.val[3] = vec_xl(48, ptr);
+
+    return res;
+}
+
+typedef struct ggml_int16x8x2_t {
+    int16x8_t val[2];
+} ggml_int16x8x2_t;
+
+inline static ggml_int16x8x2_t ggml_vec_xl_s16x2(const int16_t * ptr) {
+    ggml_int16x8x2_t res;
+
+    res.val[0] = vec_xl( 0, ptr);
+    res.val[1] = vec_xl(16, ptr);
+
+    return res;
+}
+
+/*
+    ! WARNING: Very slow. Use vec_perm if possible. Refer to iq4_xs
+    !          or iq4_nl for example implementation.
+*/
+inline static int8x16_t ggml_vec_tbl(int8x16_t a, uint8x16_t b) {
+    int8x16_t res;
+
+    res[ 0] = a[b[ 0]];
+    res[ 1] = a[b[ 1]];
+    res[ 2] = a[b[ 2]];
+    res[ 3] = a[b[ 3]];
+    res[ 4] = a[b[ 4]];
+    res[ 5] = a[b[ 5]];
+    res[ 6] = a[b[ 6]];
+    res[ 7] = a[b[ 7]];
+    res[ 8] = a[b[ 8]];
+    res[ 9] = a[b[ 9]];
+    res[10] = a[b[10]];
+    res[11] = a[b[11]];
+    res[12] = a[b[12]];
+    res[13] = a[b[13]];
+    res[14] = a[b[14]];
+    res[15] = a[b[15]];
+
+    return res;
+}
+
+inline static int16x8_t vec_padd_s16(int16x8_t a, int16x8_t b) {
+    const uchar8x16_t v_maske = {  0,  1,  4,  5,  8,  9, 12, 13,
+                                  16, 17, 20, 21, 24, 25, 28, 29 };
+
+    const int16x8_t v_abo = vec_pack((int32x4_t)a, (int32x4_t)b);
+    const int16x8_t v_abe = vec_perm(a, b, v_maske);
+    return v_abo + v_abe;
+}
+
+inline static int32x4_t ggml_vec_dot(int32x4_t acc, int8x16_t a, int8x16_t b) {
+    const int16x8_t p = vec_mule(a, b) + vec_mulo(a, b);
+    return acc + (vec_unpackh(p) + vec_unpackl(p));
+}
+
+#endif
+
+#if defined(__loongarch_asx)
+/* float type data load instructions */
+static __m128 __lsx_vreplfr2vr_s(const float val) {
+    v4f32 res = {val, val, val, val};
+    return (__m128)res;
+}
+
+static __m256 __lasx_xvreplfr2vr_s(const float val) {
+    v8f32 res = {val, val, val, val, val, val, val, val};
+    return (__m256)res;
 }
 #endif
 
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-quants.c b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-quants.c
index 8e147226..8d5e3e20 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-quants.c
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-quants.c
@@ -297,6 +297,90 @@ static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
 static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
 #endif
 
+#if defined(__loongarch_sx)
+
+static __m128i lsx_packs_w(__m128i a, __m128i b) {
+    __m128i tmp, tmp1;
+    tmp = __lsx_vsat_w(a, 15);
+    tmp1 = __lsx_vsat_w(b, 15);
+    return __lsx_vpickev_h(tmp1, tmp);
+}
+
+static __m128i lsx_packs_h(__m128i a, __m128i b) {
+    __m128i tmp, tmp1;
+    tmp = __lsx_vsat_h(a, 7);
+    tmp1 = __lsx_vsat_h(b, 7);
+    return __lsx_vpickev_b(tmp1, tmp);
+}
+
+static __m128i lsx_packus_h(__m128i a, __m128i b) {
+    __m128i tmp, tmp1;
+    tmp = __lsx_vsat_hu(a, 7);
+    tmp1 = __lsx_vsat_hu(b, 7);
+    return __lsx_vpickev_b(tmp1, tmp);
+}
+
+static __m128i lsx_maddubs_h(__m128i a, __m128i b) {
+    __m128i tmp1, tmp2;
+    tmp1 = __lsx_vmulwev_h_b(a, b);
+    tmp2 = __lsx_vmulwod_h_b(a, b);
+    return __lsx_vsadd_h(tmp1, tmp2);
+}
+
+static __m128i lsx_madd_h(__m128i a, __m128i b) {
+    __m128i tmp1, tmp2;
+    tmp1 = __lsx_vmulwev_w_h(a, b);
+    tmp2 = __lsx_vmulwod_w_h(a, b);
+    return __lsx_vadd_w(tmp1, tmp2);
+}
+
+static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) {
+    v4i32 __ret = {d, c, b, a};
+    return (__m128i)__ret;
+}
+
+static __m128i lsx_shuffle_b(__m128i a, __m128i b) {
+    __m128i mask_f, zero, tmp0, tmp2, mask;
+    int f = 0x8f;
+    mask_f = __lsx_vreplgr2vr_b(f);
+    zero = __lsx_vldi(0);
+    tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits
+    tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or  with 0x10 prepare for positive
+    mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask
+    tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones
+    return __lsx_vshuf_b(a, zero, tmp2);
+}
+
+static __m128i lsx_hadd_h(__m128i a, __m128i b) {
+    __m128i tmp1 = __lsx_vpickev_h(b, a);
+    __m128i tmp2 = __lsx_vpickod_h(b, a);
+    return __lsx_vadd_h(tmp1, tmp2);
+}
+
+static __m128i lsx_hadd_w(__m128i a, __m128i b) {
+    __m128i tmp1 = __lsx_vpickev_w(b, a);
+    __m128i tmp2 = __lsx_vpickod_w(b, a);
+    return __lsx_vadd_w(tmp1, tmp2);
+}
+
+static __m128 lsx_hadd_s(__m128 a, __m128 b) {
+    __m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a);
+    __m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a);
+
+    return __lsx_vfadd_s(tmp1, tmp2);
+}
+
+static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
+    __m128 res_0 =lsx_hadd_s(a, b);
+    __m128 res_1 =lsx_hadd_s(c, d);
+    __m128 res =lsx_hadd_s(res_0, res_1);
+    res =lsx_hadd_s(res, res);
+    res =lsx_hadd_s(res, res);
+
+    return ((v4f32)res)[0];
+}
+#endif
+
 #if defined(__loongarch_asx)
 
 #ifdef __clang__
@@ -395,11 +479,6 @@ static __m256i lasx_set_w(int e7, int e6, int e5, int e4, int e3, int e2, int e1
     return (__m256i)__ret;
 }
 
-static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) {
-    v4i32 __ret = {d, c, b, a};
-    return (__m128i)__ret;
-}
-
 static __m256i lasx_set_d(int64_t a, int64_t b, int64_t c, int64_t d) {
     v4i64 __ret = {d, c, b, a};
     return (__m256i)__ret;
@@ -409,18 +488,6 @@ static __m256i lasx_insertf128( __m128i x, __m128i y) {
     return lasx_set_q(x, y);
 }
 
-static __m128i lsx_shuffle_b(__m128i a, __m128i b) {
-    __m128i mask_f, zero, tmp0, tmp2, mask;
-    int f = 0x8f;
-    mask_f = __lsx_vreplgr2vr_b(f);
-    zero = __lsx_vldi(0);
-    tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits
-    tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or  with 0x10 prepare for positive
-    mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask
-    tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones
-    return __lsx_vshuf_b(a, zero, tmp2);
-}
-
 static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
     __m256i mask_f, zero, tmp0, tmp2, mask;
     int f = 0x8f;
@@ -434,30 +501,15 @@ static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
 }
 
 static __m256i lasx_extu8_16(__m128i a) {
-    __m128i zero = __lsx_vldi(0);
-    __m128i vlo = __lsx_vilvl_b(zero, a);
-    __m128i vhi = __lsx_vilvh_b(zero, a);
-    return lasx_set_q(vhi, vlo);
+    return __lasx_vext2xv_hu_bu(____m256i(a));
 }
 
 static __m256i lasx_ext8_16(__m128i a) {
-     __m128i sign = __lsx_vslti_b(a, 0);
-     __m128i vlo = __lsx_vilvl_b(sign, a);
-     __m128i vhi = __lsx_vilvh_b(sign, a);
-     return lasx_set_q(vhi, vlo);
+    return __lasx_vext2xv_h_b(____m256i(a));
 }
 
 static __m256i lasx_ext16_32(__m128i a) {
-    __m256i tmp1;
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 0), 0);
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 1), 1);
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 2), 2);
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 3), 3);
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 4), 4);
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 5), 5);
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 6), 6);
-    tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 7), 7);
-    return tmp1;
+    return __lasx_vext2xv_w_h(____m256i(a));
 }
 
 static __m128i lasx_extracti128( __m256i a, int pos) {
@@ -482,25 +534,6 @@ static __m128 lasx_extractf128( __m256 a, int pos) {
     return ret;
 }
 
-static __m128i lsx_hadd_h(__m128i a, __m128i b) {
-    __m128i tmp1 = __lsx_vpickev_h(b, a);
-    __m128i tmp2 = __lsx_vpickod_h(b, a);
-    return __lsx_vadd_h(tmp1, tmp2);
-}
-
-static __m128i lsx_hadd_w(__m128i a, __m128i b) {
-    __m128i tmp1 = __lsx_vpickev_w(b, a);
-    __m128i tmp2 = __lsx_vpickod_w(b, a);
-    return __lsx_vadd_w(tmp1, tmp2);
-}
-
-static __m128 lsx_hadd_s(__m128 a, __m128 b) {
-    __m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a);
-    __m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a);
-
-    return __lsx_vfadd_s(tmp1, tmp2);
-}
-
 static __m256i lasx_maddubs_h(__m256i a, __m256i b) {
     __m256i tmp1, tmp2;
     tmp1 = __lasx_xvmulwev_h_b(a, b);
@@ -529,40 +562,39 @@ static __m256i lasx_packs_h(__m256i a, __m256i b) {
     return __lasx_xvpickev_b(tmp1, tmp);
 }
 
-static __m128i lsx_packs_w(__m128i a, __m128i b) {
-    __m128i tmp, tmp1;
-    tmp = __lsx_vsat_w(a, 15);
-    tmp1 = __lsx_vsat_w(b, 15);
-    return __lsx_vpickev_h(tmp1, tmp);
+static inline __m256i lasx_madd_h_b(__m256i a, __m256i b) {
+    __m256i tmp1, tmp2;
+    tmp1 = __lasx_xvmulwev_h_b(a, b);
+    tmp2 = __lasx_xvmulwod_h_b(a, b);
+    return __lasx_xvadd_h(tmp1, tmp2);
 }
 
-static __m128i lsx_packs_h(__m128i a, __m128i b) {
-    __m128i tmp, tmp1;
-    tmp = __lsx_vsat_h(a, 7);
-    tmp1 = __lsx_vsat_h(b, 7);
-    return __lsx_vpickev_b(tmp1, tmp);
+static inline __m256i lasx_xvrepl128vei_h(__m256i a, const unsigned int b) {
+    switch (b) {
+        case 0: return __lasx_xvrepl128vei_h(a, 0);
+        case 1: return __lasx_xvrepl128vei_h(a, 1);
+        case 2: return __lasx_xvrepl128vei_h(a, 2);
+        case 3: return __lasx_xvrepl128vei_h(a, 3);
+        case 4: return __lasx_xvrepl128vei_h(a, 4);
+        case 5: return __lasx_xvrepl128vei_h(a, 5);
+        case 6: return __lasx_xvrepl128vei_h(a, 6);
+        case 7: return __lasx_xvrepl128vei_h(a, 7);
+        default: __builtin_unreachable();
+    }
 }
 
-static __m128i lsx_packus_h(__m128i a, __m128i b) {
-    __m128i tmp, tmp1;
-    tmp = __lsx_vsat_hu(a, 7);
-    tmp1 = __lsx_vsat_hu(b, 7);
-    return __lsx_vpickev_b(tmp1, tmp);
-}
-
-
-static __m128i lsx_maddubs_h(__m128i a, __m128i b) {
-    __m128i tmp1, tmp2;
-    tmp1 = __lsx_vmulwev_h_b(a, b);
-    tmp2 = __lsx_vmulwod_h_b(a, b);
-    return __lsx_vsadd_h(tmp1, tmp2);
-}
-
-static __m128i lsx_madd_h(__m128i a, __m128i b) {
-    __m128i tmp1, tmp2;
-    tmp1 = __lsx_vmulwev_w_h(a, b);
-    tmp2 = __lsx_vmulwod_w_h(a, b);
-    return __lsx_vadd_w(tmp1, tmp2);
+static inline __m256i lasx_xvandi_b_bit(__m256i a, const unsigned int b) {
+    switch (b) {
+        case 0: return __lasx_xvandi_b(a, 1 << 0);
+        case 1: return __lasx_xvandi_b(a, 1 << 1);
+        case 2: return __lasx_xvandi_b(a, 1 << 2);
+        case 3: return __lasx_xvandi_b(a, 1 << 3);
+        case 4: return __lasx_xvandi_b(a, 1 << 4);
+        case 5: return __lasx_xvandi_b(a, 1 << 5);
+        case 6: return __lasx_xvandi_b(a, 1 << 6);
+        case 7: return __lasx_xvandi_b(a, 1 << 7);
+        default: __builtin_unreachable();
+    }
 }
 
 // multiply int8_t, add results pairwise twice
@@ -580,12 +612,10 @@ static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
 // horizontally add 8 floats
 static inline float hsum_float_8(const __m256 x) {
     __m128 res = lasx_extractf128(x, 1);
-    ft_union tmp;
     res = __lsx_vfadd_s(res, lasx_extractf128(x, 0));
     res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res));
     res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0));
-    tmp.i = __lsx_vpickve2gr_w(res, 0);
-    return tmp.f;
+    return ((v4f32)res)[0];
 }
 
 // horizontally add 8 int32_t
@@ -661,13 +691,8 @@ static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy)
 
 // multiply int8_t, add results pairwise twice and return as float vector
 static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
-
-    // Get absolute values of x vectors
-    const __m256i ax = __lasx_xvsigncov_b(x, x);
-    // Sign the values of the y vectors
-    const __m256i sy = __lasx_xvsigncov_b(x, y);
-
-    return mul_sum_us8_pairs_float(ax, sy);
+    const __m256i dot = lasx_madd_h_b(x, y);
+    return sum_i16_pairs_float(dot);
 }
 
 static inline __m128i packNibbles( __m256i bytes ) {
@@ -747,7 +772,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
             y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
         }
     }
-#elif defined(__wasm_simd128__)
+#elif defined __wasm_simd128__
     for (int i = 0; i < nb; i++) {
         v128_t srcv [8];
         v128_t asrcv[8];
@@ -927,7 +952,6 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
 
 #elif defined(__loongarch_asx)
     for (int i = 0; i < nb; i++) {
-        ft_union fi;
         __m256 v0 = (__m256)__lasx_xvld( x , 0);
         __m256 v1 = (__m256)__lasx_xvld( x , 32);
         __m256 v2 = (__m256)__lasx_xvld( x , 64);
@@ -945,8 +969,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
         max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
         __m128 tmp = max4;
         max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vinsgr2vr_w(tmp, __lsx_vpickve2gr_w( max4, 1 ), 0 ));
-        fi.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 );
-        const float max_scalar = fi.f;
+        const float max_scalar = ((v4f32)max4)[0];
 
         // Quantize these floats
         const float d = max_scalar / 127.f;
@@ -988,6 +1011,38 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
         __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0);
 
     }
+#elif defined(__VXE__) || defined(__VXE2__)
+    for (int i = 0; i < nb; i++) {
+        __vector float srcv [8];
+        __vector float asrcv[8];
+        __vector float amaxv[8];
+
+        for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j);
+        for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]);
+        for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]);
+        for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]);
+        for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]);
+
+        const float amax = MAX(MAX(vec_extract(amaxv[0], 0),
+                                   vec_extract(amaxv[0], 1)),
+                               MAX(vec_extract(amaxv[0], 2),
+                                   vec_extract(amaxv[0], 3)));
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f / d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        for (int j = 0; j < 8; j++) {
+            const __vector float v = vec_mul(srcv[j], vec_splats(id));
+            const __vector int32_t vi = vec_signed(v);
+
+            y[i].qs[4*j + 0] = vec_extract(vi, 0);
+            y[i].qs[4*j + 1] = vec_extract(vi, 1);
+            y[i].qs[4*j + 2] = vec_extract(vi, 2);
+            y[i].qs[4*j + 3] = vec_extract(vi, 3);
+        }
+    }
 #else
     GGML_UNUSED(nb);
     // scalar
@@ -1037,7 +1092,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
 
         y[i].s = GGML_FP32_TO_FP16(d * vaddvq_s32(accv));
     }
-#elif defined(__wasm_simd128__)
+#elif defined __wasm_simd128__
     for (int i = 0; i < nb; i++) {
         v128_t srcv [8];
         v128_t asrcv[8];
@@ -1251,7 +1306,6 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
 
 #elif defined(__loongarch_asx)
     for (int i = 0; i < nb; i++) {
-        ft_union ft;
         __m256 v0 = (__m256)__lasx_xvld( x , 0 );
         __m256 v1 = (__m256)__lasx_xvld( x , 32 );
         __m256 v2 = (__m256)__lasx_xvld( x , 64 );
@@ -1269,8 +1323,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
         max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
         __m128 tmp = max4;
         max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x10 ));
-        ft.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 );
-        const float max_scalar = ft.f;
+        const float max_scalar = ((v4f32)max4)[0];
 
         // Quantize these floats
         const float d = max_scalar / 127.f;
@@ -1316,6 +1369,44 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
         __lsx_vst(ni0, (__m128i *)(y[i].qs +  0), 0);
         __lsx_vst(ni4, (__m128i *)(y[i].qs + 16), 0);
     }
+#elif defined(__VXE__) || defined(__VXE2__)
+    for (int i = 0; i < nb; i++) {
+        __vector float srcv [8];
+        __vector float asrcv[8];
+        __vector float amaxv[8];
+
+        for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j);
+        for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]);
+        for (int j = 0; j < 4; j++) amaxv[2*j] = vec_max(asrcv[2*j], asrcv[2*j+1]);
+        for (int j = 0; j < 2; j++) amaxv[4*j] = vec_max(amaxv[4*j], amaxv[4*j+2]);
+        for (int j = 0; j < 1; j++) amaxv[8*j] = vec_max(amaxv[8*j], amaxv[8*j+4]);
+
+        const float amax = MAX(MAX(vec_extract(amaxv[0], 0),
+                                   vec_extract(amaxv[0], 1)),
+                               MAX(vec_extract(amaxv[0], 2),
+                                   vec_extract(amaxv[0], 3)));
+
+        const float d = amax / ((1 << 7) - 1);
+        const float id = d ? 1.0f / d : 0.0f;
+
+        y[i].d = GGML_FP32_TO_FP16(d);
+
+        __vector int32_t acc = vec_splats(0);
+
+        for (int j = 0; j < 8; j++) {
+            const __vector float v = vec_mul(srcv[j], vec_splats(id));
+            const __vector int32_t vi = vec_signed(v);
+
+            y[i].qs[4*j + 0] = vec_extract(vi, 0);
+            y[i].qs[4*j + 1] = vec_extract(vi, 1);
+            y[i].qs[4*j + 2] = vec_extract(vi, 2);
+            y[i].qs[4*j + 3] = vec_extract(vi, 3);
+
+            acc = vec_add(acc, vi);
+        }
+
+        y[i].s = GGML_FP32_TO_FP16(d * (acc[0] + acc[1] + acc[2] + acc[3]));
+    }
 #else
     GGML_UNUSED(nb);
     // scalar
@@ -1653,7 +1744,87 @@ static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -1
 //===================================== Q8_K ==============================================
 
 void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) {
+#ifdef __wasm_simd128__
+    assert(k % QK_K == 0);
+    const int64_t nb = k / QK_K;
+    block_q8_K * restrict yc = y; // Cast to proper type
+
+    for (int i = 0; i < nb; i++) {
+        const float * x_block = x + i * QK_K;
+
+        v128_t min_vec = wasm_v128_load(x_block);
+        v128_t max_vec = min_vec;
+
+        for (int j = 4; j < QK_K; j += 4) {
+            v128_t x_vec = wasm_v128_load(x_block + j);
+            max_vec = wasm_f32x4_pmax(max_vec, x_vec);
+            min_vec = wasm_f32x4_pmin(min_vec, x_vec);
+        }
+        max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 2, 3, 0, 1));
+        max_vec = wasm_f32x4_pmax(max_vec, wasm_i32x4_shuffle(max_vec, max_vec, 1, 0, 3, 2));
+        min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 2, 3, 0, 1));
+        min_vec = wasm_f32x4_pmin(min_vec, wasm_i32x4_shuffle(min_vec, min_vec, 1, 0, 3, 2));
+        float max = wasm_f32x4_extract_lane(max_vec, 0);
+        float min = wasm_f32x4_extract_lane(min_vec, 0);
+        float amax = -min > max ? min : max;
+
+        if (amax == 0.0f) {
+            yc[i].d = 0.0f;
+            const v128_t zero = wasm_i8x16_splat(0);
+            for (int j = 0; j < QK_K; j += 16) {
+                wasm_v128_store(yc[i].qs + j, zero);
+            }
+            continue;
+        }
+
+        const float iscale = -127.0f / amax;
+        const v128_t scale_vec = wasm_f32x4_splat(iscale);
+
+        // Process 16 elements per iteration
+        for (int j = 0, jb = 0; j < QK_K; j += 16, jb++) {
+            // Load and quantize 16 floats
+            v128_t x0 = wasm_v128_load(x_block + j);
+            v128_t x1 = wasm_v128_load(x_block + j + 4);
+            v128_t x2 = wasm_v128_load(x_block + j + 8);
+            v128_t x3 = wasm_v128_load(x_block + j + 12);
+
+            v128_t q0 = wasm_f32x4_nearest(wasm_f32x4_mul(x0, scale_vec));
+            v128_t q1 = wasm_f32x4_nearest(wasm_f32x4_mul(x1, scale_vec));
+            v128_t q2 = wasm_f32x4_nearest(wasm_f32x4_mul(x2, scale_vec));
+            v128_t q3 = wasm_f32x4_nearest(wasm_f32x4_mul(x3, scale_vec));
+
+            // Convert to i32 with saturation
+            v128_t i0 = wasm_i32x4_trunc_sat_f32x4(q0);
+            v128_t i1 = wasm_i32x4_trunc_sat_f32x4(q1);
+            v128_t i2 = wasm_i32x4_trunc_sat_f32x4(q2);
+            v128_t i3 = wasm_i32x4_trunc_sat_f32x4(q3);
+
+            // Pack into 16 i8 values
+            v128_t i8 = wasm_i8x16_narrow_i16x8(
+                wasm_i16x8_narrow_i32x4(i0, i1),
+                wasm_i16x8_narrow_i32x4(i2, i3)
+            );
+            wasm_v128_store(yc[i].qs + j, i8);
+
+            // Calculate bsums using SIMD
+            v128_t sum16 = wasm_i16x8_add(
+                wasm_i16x8_extend_low_i8x16(i8),
+                wasm_i16x8_extend_high_i8x16(i8)
+            );
+            v128_t sum32 = wasm_i32x4_add(
+                wasm_i32x4_extend_low_i16x8(sum16),
+                wasm_i32x4_extend_high_i16x8(sum16)
+            );
+            sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 2, 3, 0, 1));
+            sum32 = wasm_i32x4_add(sum32, wasm_i32x4_shuffle(sum32, sum32, 1, 0, 3, 2));
+            yc[i].bsums[jb] = wasm_i32x4_extract_lane(sum32, 0);
+        }
+
+        yc[i].d = 1.0f / iscale;
+    }
+#else
     quantize_row_q8_K_ref(x, y, k);
+#endif
 }
 
 //===================================== Dot products =================================
@@ -2011,6 +2182,94 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     }
 
     sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined __wasm_simd128__
+    v128_t sumv = wasm_f32x4_splat(0.0f);
+
+    const v128_t m4b = wasm_i8x16_splat(0x0F);
+    const v128_t s8b = wasm_i8x16_splat(0x8);
+
+    for (; ib + 1 < nb; ib += 2) {
+        const block_q4_0 * restrict x0 = &x[ib];
+        const block_q4_0 * restrict x1 = &x[ib + 1];
+        const block_q8_0 * restrict y0 = &y[ib];
+        const block_q8_0 * restrict y1 = &y[ib + 1];
+
+        // Load and process x0
+        v128_t v0_0 = wasm_v128_load(x0->qs);
+        v128_t v0_0l = wasm_v128_and(v0_0, m4b);
+        v128_t v0_0h = wasm_u8x16_shr(v0_0, 4);
+        v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b);
+        v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b);
+
+        // Load y0 vectors
+        v128_t y0_l = wasm_v128_load(y0->qs);
+        v128_t y0_h = wasm_v128_load(y0->qs + 16);
+
+        // Extend to i16x8 and compute dot products
+        v128_t dx0l = wasm_i16x8_extend_low_i8x16(v0_0ls);
+        v128_t dx0h = wasm_i16x8_extend_high_i8x16(v0_0ls);
+        v128_t dx0hl = wasm_i16x8_extend_low_i8x16(v0_0hs);
+        v128_t dx0hh = wasm_i16x8_extend_high_i8x16(v0_0hs);
+
+        v128_t dy0ll = wasm_i16x8_extend_low_i8x16(y0_l);
+        v128_t dy0lh = wasm_i16x8_extend_high_i8x16(y0_l);
+        v128_t dy0hl = wasm_i16x8_extend_low_i8x16(y0_h);
+        v128_t dy0hh = wasm_i16x8_extend_high_i8x16(y0_h);
+
+        v128_t dp0 = wasm_i32x4_add(
+            wasm_i32x4_add(
+                wasm_i32x4_dot_i16x8(dx0l, dy0ll),
+                wasm_i32x4_dot_i16x8(dx0h, dy0lh)
+            ),
+            wasm_i32x4_add(
+                wasm_i32x4_dot_i16x8(dx0hl, dy0hl),
+                wasm_i32x4_dot_i16x8(dx0hh, dy0hh)
+            )
+        );
+
+        // Load and process x1
+        v128_t v0_1 = wasm_v128_load(x1->qs);
+        v128_t v0_1l = wasm_v128_and(v0_1, m4b);
+        v128_t v0_1h = wasm_u8x16_shr(v0_1, 4);
+        v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b);
+        v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b);
+
+        // Load y1 vectors
+        v128_t y1_l = wasm_v128_load(y1->qs);
+        v128_t y1_h = wasm_v128_load(y1->qs + 16);
+
+        // Extend to i16x8 and compute dot products
+        v128_t dx1l = wasm_i16x8_extend_low_i8x16(v0_1ls);
+        v128_t dx1h = wasm_i16x8_extend_high_i8x16(v0_1ls);
+        v128_t dx1hl = wasm_i16x8_extend_low_i8x16(v0_1hs);
+        v128_t dx1hh = wasm_i16x8_extend_high_i8x16(v0_1hs);
+
+        v128_t dy1ll = wasm_i16x8_extend_low_i8x16(y1_l);
+        v128_t dy1lh = wasm_i16x8_extend_high_i8x16(y1_l);
+        v128_t dy1hl = wasm_i16x8_extend_low_i8x16(y1_h);
+        v128_t dy1hh = wasm_i16x8_extend_high_i8x16(y1_h);
+
+        v128_t dp1 = wasm_i32x4_add(
+            wasm_i32x4_add(
+                wasm_i32x4_dot_i16x8(dx1l, dy1ll),
+                wasm_i32x4_dot_i16x8(dx1h, dy1lh)
+            ),
+            wasm_i32x4_add(
+                wasm_i32x4_dot_i16x8(dx1hl, dy1hl),
+                wasm_i32x4_dot_i16x8(dx1hh, dy1hh)
+            )
+        );
+
+        // Accumulate results with scaling
+        float scale0 = GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d);
+        float scale1 = GGML_FP16_TO_FP32(x1->d) * GGML_FP16_TO_FP32(y1->d);
+
+        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp0), wasm_f32x4_splat(scale0)));
+        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(dp1), wasm_f32x4_splat(scale1)));
+    }
+
+    sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
+           wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
@@ -2232,21 +2491,22 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     }
 
     sumf = hsum_float_8(acc);
+
 #elif defined(__loongarch_sx)
     // set constants
     const __m128i low_mask = __lsx_vreplgr2vr_b(0xF);
     const __m128i off = __lsx_vreplgr2vr_b(8);
 
     // Initialize accumulator with zeros
-    __m128 acc_0 = __lsx_vldi(0);
-    __m128 acc_1 = __lsx_vldi(0);
-    __m128 acc_2 = __lsx_vldi(0);
-    __m128 acc_3 = __lsx_vldi(0);
+    __m128 acc_0 = (__m128)__lsx_vldi(0);
+    __m128 acc_1 = (__m128)__lsx_vldi(0);
+    __m128 acc_2 = (__m128)__lsx_vldi(0);
+    __m128 acc_3 = (__m128)__lsx_vldi(0);
 
     for (; ib + 1 < nb; ib += 2) {
 
         // Compute combined scale for the block 0 and 1
-        const __m128 d_0_1 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );
+        const __m128 d_0_1 = (__m128)__lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );
 
         const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[ib].qs, 0);
 
@@ -2264,7 +2524,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         //_mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
 
         // Compute combined scale for the block 2 and 3
-        const __m128 d_2_3 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) );
+        const __m128 d_2_3 = (__m128)__lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) );
 
         const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[ib + 1].qs, 0);
 
@@ -2298,6 +2558,37 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     }
 
     sumf = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3);
+#elif defined(__VXE__) || defined(__VXE2__)
+    __vector float acc = vec_splats(0.0f);
+
+    const __vector uint8_t v_m = vec_splats((const uint8_t)0x0F);
+    const __vector int8_t  v_s = vec_splats( (const int8_t)0x08);
+
+    for (; ib < nb; ++ib) {
+        const __vector uint8_t v_x = vec_xl(0, x[ib].qs);
+        const __vector int8_t v_xl = (const __vector int8_t)(v_x & v_m);
+        const __vector int8_t v_xh = (const __vector int8_t)(v_x >> 4);
+
+        const __vector int8_t v_xls = vec_sub(v_xl, v_s);
+        const __vector int8_t v_xhs = vec_sub(v_xh, v_s);
+
+        const __vector int8_t v_yl = vec_xl(0      , y[ib].qs);
+        const __vector int8_t v_yh = vec_xl(QK8_0/2, y[ib].qs);
+
+        const __vector int16_t v_xylso = vec_mulo(v_xls, v_yl);
+        const __vector int16_t v_xylse = vec_mule(v_xls, v_yl);
+        const __vector int16_t v_xyhso = vec_mulo(v_xhs, v_yh);
+        const __vector int16_t v_xyhse = vec_mule(v_xhs, v_yh);
+
+        __vector int16_t v_xy_ = v_xylso + v_xylse + v_xyhso + v_xyhse; v_xy_ += vec_reve(v_xy_);
+
+        const __vector float v_xy = vec_float(vec_unpackh(v_xy_));
+        const __vector float v_d = vec_splats(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d));
+
+        acc = vec_madd(v_xy, v_d, acc);
+    }
+
+    sumf = acc[0] + acc[1] + acc[2] + acc[3];
 #endif
     for (; ib < nb; ++ib) {
         int sumi0 = 0;
@@ -2591,6 +2882,35 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
     }
 
     sumf = hsum_float_8(acc) + summs;
+#elif defined(__VXE__) || defined(__VXE2__)
+    float summs = 0;
+    float32x4_t acc = vec_splats(0.0f);
+
+    const uint8x16_t v_m = vec_splat_u8(0x0F);
+
+#pragma GCC unroll 4
+    for (; ib < nb; ++ib) {
+        __builtin_prefetch(x[ib].qs, 0, 1);
+        __builtin_prefetch(y[ib].qs, 0, 1);
+
+        summs += GGML_FP16_TO_FP32(x[ib].m) * GGML_FP16_TO_FP32(y[ib].s);
+
+        const uint8x16_t v_x = vec_xl(0, x[ib].qs);
+        const int8x16_t v_xl = (const int8x16_t)(v_x & v_m);
+        const int8x16_t v_xh = (const int8x16_t)(v_x >> 4);
+
+        const int8x16_t v_yl = vec_xl(0      , y[ib].qs);
+        const int8x16_t v_yh = vec_xl(QK8_1/2, y[ib].qs);
+
+        const int32x4_t v_xy_ = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh);
+        const float32x4_t v_xy = vec_float(v_xy_);
+
+        const float32x4_t v_d = vec_splats(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d));
+
+        acc = vec_madd(v_xy, v_d, acc);
+    }
+
+    sumf = acc[0] + acc[1] + acc[2] + acc[3] + summs;
 #endif
     for (; ib < nb; ++ib) {
         int sumi0 = 0;
@@ -2696,10 +3016,10 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     }
 
     sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
-#elif defined(__wasm_simd128__)
+#elif defined __wasm_simd128__
     v128_t sumv = wasm_f32x4_splat(0.0f);
 
-    uint32_t qh;
+    uint32_t qh_;
     uint64_t tmp[4];
 
     // TODO: check if unrolling this is better
@@ -2710,12 +3030,12 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
         const v128_t m4b  = wasm_i8x16_splat(0x0F);
 
         // extract the 5th bit
-        memcpy(&qh, x0->qh, sizeof(qh));
+        memcpy(&qh_, x0->qh, sizeof(qh_));
 
-        tmp[0] = table_b2b_1[(qh >>  0) & 0xFF];
-        tmp[1] = table_b2b_1[(qh >>  8) & 0xFF];
-        tmp[2] = table_b2b_1[(qh >> 16) & 0xFF];
-        tmp[3] = table_b2b_1[(qh >> 24)       ];
+        tmp[0] = table_b2b_1[(qh_ >>  0) & 0xFF];
+        tmp[1] = table_b2b_1[(qh_ >>  8) & 0xFF];
+        tmp[2] = table_b2b_1[(qh_ >> 16) & 0xFF];
+        tmp[3] = table_b2b_1[(qh_ >> 24)       ];
 
         const v128_t qhl = wasm_v128_load(tmp + 0);
         const v128_t qhh = wasm_v128_load(tmp + 2);
@@ -3057,12 +3377,12 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
     }
 
     sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
-#elif defined(__wasm_simd128__)
+#elif defined __wasm_simd128__
     v128_t sumv = wasm_f32x4_splat(0.0f);
 
     float summs = 0.0f;
 
-    uint32_t qh;
+    uint32_t qh_;
     uint64_t tmp[4];
 
     // TODO: check if unrolling this is better
@@ -3075,12 +3395,12 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
         const v128_t m4b = wasm_i8x16_splat(0x0F);
 
         // extract the 5th bit
-        memcpy(&qh, x0->qh, sizeof(qh));
+        memcpy(&qh_, x0->qh, sizeof(qh_));
 
-        tmp[0] = table_b2b_0[(qh >>  0) & 0xFF];
-        tmp[1] = table_b2b_0[(qh >>  8) & 0xFF];
-        tmp[2] = table_b2b_0[(qh >> 16) & 0xFF];
-        tmp[3] = table_b2b_0[(qh >> 24)       ];
+        tmp[0] = table_b2b_0[(qh_ >>  0) & 0xFF];
+        tmp[1] = table_b2b_0[(qh_ >>  8) & 0xFF];
+        tmp[2] = table_b2b_0[(qh_ >> 16) & 0xFF];
+        tmp[3] = table_b2b_0[(qh_ >> 24)       ];
 
         const v128_t qhl = wasm_v128_load(tmp + 0);
         const v128_t qhh = wasm_v128_load(tmp + 2);
@@ -3573,6 +3893,45 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     }
 
     sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined __wasm_simd128__
+    v128_t sumv = wasm_f32x4_splat(0.0f);
+
+    for (; ib < nb; ++ib) {
+        const block_q8_0 * restrict x0 = &x[ib];
+        const block_q8_0 * restrict y0 = &y[ib];
+
+        const v128_t x0_0 = wasm_v128_load(x0->qs);
+        const v128_t x0_1 = wasm_v128_load(x0->qs + 16);
+        const v128_t y0_0 = wasm_v128_load(y0->qs);
+        const v128_t y0_1 = wasm_v128_load(y0->qs + 16);
+
+        // Extend 8-bit to 16-bit
+        const v128_t x0_0l = wasm_i16x8_extend_low_i8x16(x0_0);
+        const v128_t x0_0h = wasm_i16x8_extend_high_i8x16(x0_0);
+        const v128_t x0_1l = wasm_i16x8_extend_low_i8x16(x0_1);
+        const v128_t x0_1h = wasm_i16x8_extend_high_i8x16(x0_1);
+
+        const v128_t y0_0l = wasm_i16x8_extend_low_i8x16(y0_0);
+        const v128_t y0_0h = wasm_i16x8_extend_high_i8x16(y0_0);
+        const v128_t y0_1l = wasm_i16x8_extend_low_i8x16(y0_1);
+        const v128_t y0_1h = wasm_i16x8_extend_high_i8x16(y0_1);
+
+        // Compute dot products
+        const v128_t dx0_0 = wasm_i32x4_dot_i16x8(x0_0l, y0_0l);
+        const v128_t dx0_1 = wasm_i32x4_dot_i16x8(x0_0h, y0_0h);
+        const v128_t dx1_0 = wasm_i32x4_dot_i16x8(x0_1l, y0_1l);
+        const v128_t dx1_1 = wasm_i32x4_dot_i16x8(x0_1h, y0_1h);
+
+        // Sum all dot products
+        const v128_t sum_dots = wasm_i32x4_add(wasm_i32x4_add(dx0_0, dx0_1), wasm_i32x4_add(dx1_0, dx1_1));
+
+        // Convert to float and accumulate
+        const float scale = GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d);
+        sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(sum_dots), wasm_f32x4_splat(scale)));
+    }
+
+    sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
+           wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3);
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
@@ -3686,6 +4045,27 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
     }
 
     sumf = hsum_float_8(acc);
+#elif defined(__VXE__) || defined(__VXE2__)
+    __vector float acc = vec_splats(0.0f);
+
+#pragma GCC unroll 8
+    for (; ib < nb; ++ib) {
+        __builtin_prefetch(x[ib].qs, 0, 1);
+        __builtin_prefetch(y[ib].qs, 0, 1);
+
+        const int8x16_t v_xl = vec_xl(0      , x[ib].qs);
+        const int8x16_t v_xh = vec_xl(QK8_0/2, x[ib].qs);
+        const int8x16_t v_yl = vec_xl(0      , y[ib].qs);
+        const int8x16_t v_yh = vec_xl(QK8_0/2, y[ib].qs);
+
+        const int32x4_t v_xy_ = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh);
+        const float32x4_t v_xy = vec_float(v_xy_);
+        const float32x4_t v_d = vec_splats(GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d));
+
+        acc = vec_madd(v_xy, v_d, acc);
+    }
+
+    sumf = acc[0] + acc[1] + acc[2] + acc[3];
 #endif
     for (; ib < nb; ++ib) {
         int sumi = 0;
@@ -4447,6 +4827,106 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     *s = hsum_float_8(acc);
 
+#elif defined __wasm_simd128__
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * q2 = x[i].qs;
+        const int8_t * q8 = y[i].qs;
+        const uint8_t * sc = x[i].scales;
+
+        // Vectorized summs calculation
+        v128_t summs_vec = wasm_i32x4_splat(0);
+        {
+            v128_t sc_vec = wasm_v128_load(sc);
+            v128_t sc_upper = wasm_u8x16_shr(sc_vec, 4);
+
+            v128_t sc_low = wasm_u16x8_extend_low_u8x16(sc_upper);
+            v128_t sc_high = wasm_u16x8_extend_high_u8x16(sc_upper);
+
+            v128_t bsums1 = wasm_v128_load(&y[i].bsums[0]);
+            v128_t bsums2 = wasm_v128_load(&y[i].bsums[8]);
+
+            summs_vec = wasm_i32x4_add(
+                wasm_i32x4_add(wasm_i32x4_dot_i16x8(sc_low, bsums1),
+                               wasm_i32x4_dot_i16x8(sc_high, bsums2)),
+                summs_vec
+            );
+
+            summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 2, 3, 0, 1));
+            summs_vec = wasm_i32x4_add(summs_vec, wasm_i32x4_shuffle(summs_vec, summs_vec, 1, 0, 3, 2));
+        }
+        int32_t summs = wasm_i32x4_extract_lane(summs_vec, 0);
+
+        // Vectorized isum calculation
+        int32_t isum = 0;
+        const uint8_t * sc_ptr = sc;
+        const int k_iters = QK_K/128;
+
+        for (int k = 0; k < k_iters; ++k) {
+            v128_t isum_vec = wasm_i32x4_splat(0);
+            int shift = 0;
+
+            for (int j = 0; j < 4; ++j) {
+                const int d0 = (sc_ptr[0] & 0xF);
+                const int d1 = (sc_ptr[1] & 0xF);
+                sc_ptr += 2;
+
+                // Process first 16 elements
+                v128_t q2_0 = wasm_v128_load(q2);
+                v128_t q8_0 = wasm_v128_load(q8);
+                v128_t q2_shift_0 = wasm_u8x16_shr(q2_0, shift);
+                v128_t q2_bits_0 = wasm_v128_and(q2_shift_0, wasm_i8x16_splat(0x03));
+
+                // Process next 16 elements
+                v128_t q2_1 = wasm_v128_load(q2 + 16);
+                v128_t q8_1 = wasm_v128_load(q8 + 16);
+                v128_t q2_shift_1 = wasm_u8x16_shr(q2_1, shift);
+                v128_t q2_bits_1 = wasm_v128_and(q2_shift_1, wasm_i8x16_splat(0x03));
+
+                // Calculate dot products
+                v128_t p0 = wasm_i32x4_dot_i16x8(
+                    wasm_i16x8_extend_low_i8x16(q8_0),
+                    wasm_i16x8_extend_low_i8x16(q2_bits_0)
+                );
+                v128_t p1 = wasm_i32x4_dot_i16x8(
+                    wasm_i16x8_extend_high_i8x16(q8_0),
+                    wasm_i16x8_extend_high_i8x16(q2_bits_0)
+                );
+                v128_t p2 = wasm_i32x4_dot_i16x8(
+                    wasm_i16x8_extend_low_i8x16(q8_1),
+                    wasm_i16x8_extend_low_i8x16(q2_bits_1)
+                );
+                v128_t p3 = wasm_i32x4_dot_i16x8(
+                    wasm_i16x8_extend_high_i8x16(q8_1),
+                    wasm_i16x8_extend_high_i8x16(q2_bits_1)
+                );
+
+                // Accumulate scaled results
+                v128_t scaled = wasm_i32x4_add(
+                    wasm_i32x4_mul(wasm_i32x4_add(p0, p1), wasm_i32x4_splat(d0)),
+                    wasm_i32x4_mul(wasm_i32x4_add(p2, p3), wasm_i32x4_splat(d1))
+                );
+
+                isum_vec = wasm_i32x4_add(isum_vec, scaled);
+                q8 += 32;
+                shift += 2;
+            }
+            q2 += 32;
+
+            // Horizontal sum of isum_vec
+            isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 2, 3, 0, 1));
+            isum_vec = wasm_i32x4_add(isum_vec, wasm_i32x4_shuffle(isum_vec, isum_vec, 1, 0, 3, 2));
+            isum += wasm_i32x4_extract_lane(isum_vec, 0);
+        }
+
+        const float dall = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const float dmin = GGML_FP16_TO_FP32(x[i].dmin) * y[i].d;
+        sumf += dall * isum - dmin * summs;
+    }
+
+    *s = sumf;
+
 #elif defined __riscv_v_intrinsic
 
     float sumf = 0;
@@ -4666,9 +5146,6 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
 #elif defined __loongarch_asx
 
-    const __m256i m3 = __lasx_xvreplgr2vr_b(3);
-    const __m128i m4 = __lsx_vreplgr2vr_b(0xF);
-
     __m256 acc = (__m256)__lasx_xvldi(0);
 
     for (int i = 0; i < nb; ++i) {
@@ -4679,18 +5156,15 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         const uint8_t * restrict q2 = x[i].qs;
         const int8_t  * restrict q8 = y[i].qs;
 
-        const __m128i mins_and_scales = __lsx_vld((const __m128i*)x[i].scales, 0);
-        const __m128i scales8 = __lsx_vand_v(mins_and_scales, m4);
-        const __m128i mins8 = __lsx_vand_v(__lsx_vsrli_h(mins_and_scales, 4), m4);
-        const __m256i mins = lasx_ext8_16(mins8);
+        const __m128i mins_and_scales128 = __lsx_vld((const __m128i*)x[i].scales, 0);
+        const __m128i scales128 = __lsx_vandi_b(mins_and_scales128, 0xf);
+        const __m256i mins = lasx_ext8_16(__lsx_vsrli_b(mins_and_scales128, 4));
         const __m256i prod = lasx_madd_h(mins, __lasx_xvld((const __m256i*)y[i].bsums, 0));
 
         acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(dmin), __lasx_xvffint_s_w(prod), acc);
 
-        const __m256i all_scales = lasx_ext8_16(scales8);
-        const __m128i l_scales = lasx_extracti128(all_scales, 0);
-        const __m128i h_scales = lasx_extracti128(all_scales, 1);
-        const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)};
+        const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
+        const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));
 
         __m256i sumi = __lasx_xvldi(0);
 
@@ -4703,20 +5177,20 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * r
             const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
             const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
 
-            const __m256i q2_0 = __lasx_xvand_v(q2bits, m3);
-            const __m256i q2_1 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 2), m3);
-            const __m256i q2_2 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 4), m3);
-            const __m256i q2_3 = __lasx_xvand_v(__lasx_xvsrli_h(q2bits, 6), m3);
+            const __m256i q2_0 = __lasx_xvandi_b(q2bits, 3);
+            const __m256i q2_1 = __lasx_xvandi_b(__lasx_xvsrli_b(q2bits, 2), 3);
+            const __m256i q2_2 = __lasx_xvandi_b(__lasx_xvsrli_b(q2bits, 4), 3);
+            const __m256i q2_3 = __lasx_xvsrli_b(q2bits, 6);
 
-            __m256i p0 = lasx_maddubs_h(q2_0, q8_0);
-            __m256i p1 = lasx_maddubs_h(q2_1, q8_1);
-            __m256i p2 = lasx_maddubs_h(q2_2, q8_2);
-            __m256i p3 = lasx_maddubs_h(q2_3, q8_3);
+            __m256i p0 = lasx_madd_h_b(q2_0, q8_0);
+            __m256i p1 = lasx_madd_h_b(q2_1, q8_1);
+            __m256i p2 = lasx_madd_h_b(q2_2, q8_2);
+            __m256i p3 = lasx_madd_h_b(q2_3, q8_3);
 
-            p0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(0)), p0);
-            p1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(1)), p1);
-            p2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(2)), p2);
-            p3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(3)), p3);
+            p0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p0);
+            p1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p1);
+            p2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p2);
+            p3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p3);
 
             p0 = __lasx_xvadd_w(p0, p1);
             p2 = __lasx_xvadd_w(p2, p3);
@@ -4789,7 +5263,182 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     const int nb = n / QK_K;
 
-#ifdef __ARM_NEON
+#if defined(__ARM_FEATURE_SVE)
+
+    uint32_t aux[3];
+    uint32_t utmp[4];
+
+    const int8_t m32 = 32;
+    const int vector_length = svcntb()*8;
+    const svuint8_t m3b_sv = svdup_n_u8(0x3);
+    const svint32_t vzero_sv = svdup_n_s32(0);
+
+    const svuint8_t m0_sv = svdup_n_u8(1);
+    const svuint8_t m1_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 1);
+    const svuint8_t m2_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 2);
+    const svuint8_t m3_sv = svlsl_n_u8_x(svptrue_b8(), m0_sv, 3);
+
+    float sum = 0;
+
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict q3_sv = x[i].qs;
+        const uint8_t * restrict qh_sv = x[i].hmask;
+        const int8_t  * restrict q8_sv = y[i].qs;
+
+        // Set up scales
+        memcpy(aux, x[i].scales, 12);
+        utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
+        utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
+        utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
+        utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
+
+        int8_t * scale = (int8_t *)utmp;
+
+        for (int j = 0; j < 16; ++j) scale[j] -= m32;
+
+        switch (vector_length) {
+            case 128:
+                {
+                    svuint8_t qhbits_sv_1 = svld1_u8(svptrue_b8(), qh_sv);
+                    svuint8_t qhbits_sv_2 = svld1_u8(svptrue_b8(), qh_sv+16);
+                    svuint8_t q3h_sv;
+
+                    svint32_t sumi1_1 = svdup_n_s32(0);
+                    svint8_t q3bytes_sv;
+
+                    for (int j = 0; j < QK_K/128; ++j) {
+
+                        const svuint8_t q3bits_sv = svld1_u8(svptrue_b8(), q3_sv); q3_sv += 16;
+                        const svuint8_t q3bits_sv_1 = svld1_u8(svptrue_b8(), q3_sv); q3_sv += 16;
+                        svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+                        svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                        q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m0_sv, qhbits_sv_1), 2);
+                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv));
+
+                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0]));
+
+                        q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m0_sv, qhbits_sv_2), 2);
+                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q3bits_sv_1, m3b_sv)), svreinterpret_s8_u8(q3h_sv));
+
+                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1]));
+
+                        q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+                        q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                        q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m1_sv, qhbits_sv_1), 1);
+                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
+
+                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2]));
+
+                        q3h_sv = svlsl_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m1_sv, qhbits_sv_2), 1);
+                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
+
+                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3]));
+
+
+                        scale += 4;
+                        q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+                        q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                        q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_1);
+                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
+
+                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[0]));
+
+                        q3h_sv = svbic_u8_x(svptrue_b8(), m2_sv, qhbits_sv_2);
+                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
+
+                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[1]));
+
+
+                        q8bytes_1_sv_1 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+                        q8bytes_1_sv_2 = svld1_s8(svptrue_b8(), q8_sv); q8_sv += 16;
+
+                        q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_1), 1);
+                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
+
+                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), svdup_n_s32((int32_t)scale[2]));
+
+                        q3h_sv = svlsr_n_u8_x(svptrue_b8(), svbic_u8_x(svptrue_b8(), m3_sv, qhbits_sv_2), 1);
+                        q3bytes_sv = svsub_s8_x(svptrue_b8(), svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), svlsr_n_u8_x(svptrue_b8(), q3bits_sv_1, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
+
+                        sumi1_1 = svmla_s32_m(svptrue_b32(), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), svdup_n_s32((int32_t)scale[3]));
+
+                        if (j == 0) {
+                            qhbits_sv_1 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_1, 4);
+                            qhbits_sv_2 = svlsr_n_u8_x(svptrue_b8(), qhbits_sv_2, 4);
+                        }
+
+                        scale += 4;
+                    }
+
+                    sum += d * (svaddv_s32(svptrue_b32(), sumi1_1));
+                } break;
+            case 256:
+            case 512:
+                {
+                    svuint8_t qhbits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), qh_sv);
+                    svuint8_t q3h_sv;
+
+                    svint32_t sumi1_1 = svdup_n_s32(0);
+                    svint8_t q3bytes_sv;
+
+                    for (int j = 0; j < QK_K/128; ++j) {
+
+                        const svuint8_t q3bits_sv = svld1_u8(svptrue_pat_b8(SV_VL32), q3_sv); q3_sv += 32;
+                        svint8_t q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+                        svint8_t q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                        q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m0_sv, qhbits_sv), 2);
+                        q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), q3bits_sv, m3b_sv)), svreinterpret_s8_u8(q3h_sv));
+
+
+                        svint32_t scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1]));
+                        sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1);
+
+                        q3h_sv = svlsl_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m1_sv, qhbits_sv), 1);
+                        q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 2), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
+
+                        scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3]));
+                        sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1);
+
+                        scale += 4;
+                        q8bytes_1_sv_1 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+                        q8bytes_1_sv_2 = svld1_s8(svptrue_pat_b8(SV_VL32), q8_sv); q8_sv += 32;
+
+                        q3h_sv = svbic_u8_x(svptrue_pat_b8(SV_VL32), m2_sv, qhbits_sv);
+                        q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 4), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
+
+                        scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[0]), svdup_n_s32((int32_t)scale[1]));
+                        sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_1), scale_1);
+
+                        q3h_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), svbic_u8_x(svptrue_pat_b8(SV_VL32), m3_sv, qhbits_sv), 1);
+                        q3bytes_sv = svsub_s8_x(svptrue_pat_b8(SV_VL32), svreinterpret_s8_u8(svand_u8_m(svptrue_pat_b8(SV_VL32), svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q3bits_sv, 6), m3b_sv)), svreinterpret_s8_u8(q3h_sv));
+
+                        scale_1 = svsel_s32(svptrue_pat_b32(SV_VL4), svdup_n_s32((int32_t)scale[2]), svdup_n_s32((int32_t)scale[3]));
+                        sumi1_1 = svmla_s32_m(svptrue_pat_b32(SV_VL8), sumi1_1, svdot_s32(vzero_sv, q3bytes_sv, q8bytes_1_sv_2), scale_1);
+
+                        if (j == 0) {
+                            qhbits_sv = svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), qhbits_sv, 4);
+                        }
+
+                        scale += 4;
+                    }
+
+                    sum += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), sumi1_1));
+                } break;
+            default:
+                assert(false && "Unsupported vector length");
+                break;
+        }
+    }
+    *s = sum;
+
+#elif __ARM_NEON
 
     uint32_t aux[3];
     uint32_t utmp[4];
@@ -5129,6 +5778,94 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     *s = hsum_float_8(acc);
 
+#elif defined __wasm_simd128__
+    int8_t  aux8[QK_K];
+    float   sums[8] = {0};
+    uint32_t auxs[4];
+
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+        const uint8_t * restrict q3 = x[i].qs;
+        const uint8_t * restrict hm = x[i].hmask;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        // Process blocks with SIMD
+        int8_t * a = aux8;
+        uint8_t m = 1;
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int shift = 0; shift <= 6; shift += 2) {
+                v128_t v_m = wasm_i8x16_splat(m);
+                for (int l = 0; l < 32; l += 16) {
+                    v128_t v_q3 = wasm_v128_load(q3 + l);
+                    v128_t v_shift = wasm_i8x16_shr(v_q3, shift);
+                    v128_t v_low2 = wasm_v128_and(v_shift, wasm_i8x16_splat(0x03));
+
+                    v128_t v_hm = wasm_v128_load(hm + l);
+                    v128_t v_mask = wasm_v128_and(v_hm, v_m);
+                    v_mask = wasm_i8x16_ne(v_mask, wasm_i8x16_splat(0));
+
+                    v_low2 = wasm_i8x16_sub(v_low2, wasm_v128_and(wasm_i8x16_splat(4), wasm_v128_not(v_mask)));
+                    wasm_v128_store(a + l, v_low2);
+                }
+                a += 32;
+                m <<= 1;
+            }
+            q3 += 32;
+        }
+
+        // Extract scales
+        memcpy(auxs, x[i].scales, 12);
+        uint32_t tmp = auxs[2];
+        auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
+        auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
+        auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
+        auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
+        const int8_t * scales = (const int8_t *)auxs;
+
+        // SIMD dot product with register accumulators
+        v128_t v_acc0 = wasm_i32x4_splat(0);
+        v128_t v_acc1 = wasm_i32x4_splat(0);
+        a = aux8;
+        for (int j = 0; j < QK_K/16; ++j) {
+            const v128_t v_scale = wasm_i16x8_splat(scales[j] - 32);
+
+            // Process 16 elements per iteration
+            for (int k = 0; k < 2; ++k) {
+                const v128_t v_q8 = wasm_i16x8_load8x8(q8);
+                const v128_t v_a = wasm_i16x8_load8x8(a);
+
+                v128_t v_prod = wasm_i16x8_mul(v_q8, v_a);
+                v_prod = wasm_i16x8_mul(v_prod, v_scale);
+
+                v_acc0 = wasm_i32x4_add(v_acc0, wasm_i32x4_extend_low_i16x8(v_prod));
+                v_acc1 = wasm_i32x4_add(v_acc1, wasm_i32x4_extend_high_i16x8(v_prod));
+
+                q8 += 8;
+                a += 8;
+            }
+        }
+
+        // Accumulate results
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        const v128_t v_d = wasm_f32x4_splat(d);
+        v128_t v_sum = wasm_f32x4_add(
+            wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc0), v_d),
+            wasm_f32x4_mul(wasm_f32x4_convert_i32x4(v_acc1), v_d)
+        );
+
+        // Accumulate into sums vector
+        wasm_v128_store(sums, wasm_f32x4_add(wasm_v128_load(sums), v_sum));
+    }
+
+    // Horizontal sum
+    v128_t v_sum = wasm_f32x4_add(wasm_v128_load(sums), wasm_v128_load(sums + 4));
+    sumf = wasm_f32x4_extract_lane(v_sum, 0) +
+           wasm_f32x4_extract_lane(v_sum, 1) +
+           wasm_f32x4_extract_lane(v_sum, 2) +
+           wasm_f32x4_extract_lane(v_sum, 3);
+
+    *s = sumf;
+
 #elif defined __riscv_v_intrinsic
 
     uint32_t aux[3];
@@ -5384,8 +6121,6 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
 #elif defined __loongarch_asx
 
-    const __m256i m3 = __lasx_xvreplgr2vr_b(3);
-    const __m256i mone = __lasx_xvreplgr2vr_b(1);
     const __m128i m32 = __lsx_vreplgr2vr_b(32);
 
     __m256 acc = (__m256)__lasx_xvldi(0);
@@ -5405,10 +6140,9 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
                 (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
                 (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
         scales128 = __lsx_vsub_b(scales128, m32);
-        const __m256i all_scales = lasx_ext8_16(scales128);
-        const __m128i l_scales = lasx_extracti128(all_scales, 0);
-        const __m128i h_scales = lasx_extracti128(all_scales, 1);
-        const __m256i scales[2] = {lasx_insertf128(l_scales, l_scales), lasx_insertf128(h_scales, h_scales)};
+
+        const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
+        const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));
 
         // high bit
         const __m256i hbits = __lasx_xvld((const __m256i*)x[i].hmask, 0);
@@ -5416,35 +6150,23 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         // integer accumulator
         __m256i sumi = __lasx_xvldi(0);
 
-        int bit = 0;
-        int is  = 0;
-        __m256i xvbit;
-
-
         for (int j = 0; j < QK_K/128; ++j) {
             // load low 2 bits
             const __m256i q3bits = __lasx_xvld((const __m256i*)q3, 0); q3 += 32;
 
-            xvbit = __lasx_xvreplgr2vr_h(bit);
             // prepare low and high bits
-            const __m256i q3l_0 = __lasx_xvand_v(q3bits, m3);
-            const __m256i q3h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
-            ++bit;
-
-            xvbit = __lasx_xvreplgr2vr_h(bit);
-            const __m256i q3l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 2), m3);
-            const __m256i q3h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
-            ++bit;
-
-            xvbit = __lasx_xvreplgr2vr_h(bit);
-            const __m256i q3l_2 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 4), m3);
-            const __m256i q3h_2 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
-            ++bit;
-
-            xvbit = __lasx_xvreplgr2vr_h(bit);
-            const __m256i q3l_3 = __lasx_xvand_v(__lasx_xvsrli_h(q3bits, 6), m3);
-            const __m256i q3h_3 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvandn_v(hbits, __lasx_xvsll_h(mone, xvbit)), xvbit), 2);
-            ++bit;
+            const __m256i q3l_0 = __lasx_xvandi_b(q3bits, 3);
+            const __m256i q3l_1 = __lasx_xvandi_b(__lasx_xvsrli_b(q3bits, 2), 3);
+            const __m256i q3l_2 = __lasx_xvandi_b(__lasx_xvsrli_b(q3bits, 4), 3);
+            const __m256i q3l_3 = __lasx_xvsrli_b(q3bits, 6);
+            const __m256i q3h_0 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 0), 0), 2);
+            const __m256i q3h_1 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 1), 0), 2);
+            const __m256i q3h_2 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 2), 0), 2);
+            const __m256i q3h_3 = __lasx_xvslli_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 4 * j + 3), 0), 2);
+            const __m256i q3_0 = __lasx_xvor_v(q3h_0, q3l_0);
+            const __m256i q3_1 = __lasx_xvor_v(q3h_1, q3l_1);
+            const __m256i q3_2 = __lasx_xvor_v(q3h_2, q3l_2);
+            const __m256i q3_3 = __lasx_xvor_v(q3h_3, q3l_3);
 
             // load Q8 quants
             const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
@@ -5452,29 +6174,16 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
             const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
             const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
 
-            // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use lasx_maddubs_h,
-            // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set,
-            // and 2 if the high bit was set)
-            __m256i q8s_0 = lasx_maddubs_h(q3h_0, q8_0);
-            __m256i q8s_1 = lasx_maddubs_h(q3h_1, q8_1);
-            __m256i q8s_2 = lasx_maddubs_h(q3h_2, q8_2);
-            __m256i q8s_3 = lasx_maddubs_h(q3h_3, q8_3);
-
-            __m256i p16_0 = lasx_maddubs_h(q3l_0, q8_0);
-            __m256i p16_1 = lasx_maddubs_h(q3l_1, q8_1);
-            __m256i p16_2 = lasx_maddubs_h(q3l_2, q8_2);
-            __m256i p16_3 = lasx_maddubs_h(q3l_3, q8_3);
-
-            p16_0 = __lasx_xvsub_h(p16_0, q8s_0);
-            p16_1 = __lasx_xvsub_h(p16_1, q8s_1);
-            p16_2 = __lasx_xvsub_h(p16_2, q8s_2);
-            p16_3 = __lasx_xvsub_h(p16_3, q8s_3);
+            __m256i p16_0 = lasx_madd_h_b(q8_0, q3_0);
+            __m256i p16_1 = lasx_madd_h_b(q8_1, q3_1);
+            __m256i p16_2 = lasx_madd_h_b(q8_2, q3_2);
+            __m256i p16_3 = lasx_madd_h_b(q8_3, q3_3);
 
             // multiply with scales
-            p16_0 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0);
-            p16_1 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1);
-            p16_2 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2);
-            p16_3 = lasx_madd_h(lasx_shuffle_b(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3);
+            p16_0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p16_0);
+            p16_1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p16_1);
+            p16_2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p16_2);
+            p16_3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p16_3);
 
             // accumulate
             p16_0 = __lasx_xvadd_w(p16_0, p16_1);
@@ -5482,7 +6191,7 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, size_t bs, const void * r
             sumi  = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_2));
         }
         // multiply with block scale and accumulate
-        acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);//FIXME
+        acc = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(d), __lasx_xvffint_s_w(sumi), acc);
     }
 
     *s = hsum_float_8(acc);
@@ -5573,7 +6282,88 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     uint32_t utmp[4];
 
-#ifdef __ARM_NEON
+#ifdef __ARM_FEATURE_SVE
+    float sumf = 0;
+    for (int i = 0; i < nb; ++i) {
+
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
+
+        memcpy(utmp, x[i].scales, K_SCALE_SIZE);
+
+        uint32x2_t mins8 = { 0 };
+        mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0);
+        mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1);
+
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[0] &= kmask1;
+
+        const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
+        const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
+                                         vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
+        sumf -= dmin * vaddvq_s32(prod);
+
+        const uint8_t * scales = (const uint8_t *)utmp;
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        const int vector_length = ggml_cpu_get_sve_cnt()*8;
+        const svuint8_t m4b = svdup_n_u8(0xf);
+        const svint32_t mzero = svdup_n_s32(0);
+        svint32_t sumi1 = svdup_n_s32(0);
+        svint32_t sumi1_1 = svdup_n_s32(0);
+        svint32_t sumi1_2 = svdup_n_s32(0);
+        svint32_t sumi2 = svdup_n_s32(0);
+        svint32_t sumi2_1 = svdup_n_s32(0);
+        svint32_t sumi2_2 = svdup_n_s32(0);
+        switch (vector_length) {
+            case 128:
+                {
+                    for (int j = 0; j < QK_K/64; ++j) {
+                        svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), m4b));
+                        svint8_t q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi1_1 = svmla_n_s32_x(svptrue_b32(), sumi1_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
+                        q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), m4b));
+                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi1_2 = svmla_n_s32_x(svptrue_b32(), sumi1_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
+
+                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4), 4));
+                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi2_1 = svmla_n_s32_x(svptrue_b32(), sumi2_1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
+                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_b8(), svld1_u8(svptrue_b8(), q4+16), 4));
+                        q8bytes = svld1_s8(svptrue_b8(), q8); q8 += 16;
+                        sumi2_2 = svmla_n_s32_x(svptrue_b32(), sumi2_2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
+                        q4 += 32;
+                    }
+                    sumi1 = svadd_s32_x(svptrue_b32(), sumi1_1, sumi1_2);
+                    sumi2 = svadd_s32_x(svptrue_b32(), sumi2_1, sumi2_2);
+                    sumf += d * (svaddv_s32(svptrue_b32(), svadd_s32_x(svptrue_b32(), sumi1, sumi2)));
+                } break;
+            case 256:
+            case 512:
+                {
+                    for (int j = 0; j < QK_K/64; ++j) {
+                        const svuint8_t q4bits  = svld1_u8(svptrue_pat_b8(SV_VL32), q4); q4 += 32;
+                        svint8_t q4bytes = svreinterpret_s8_u8(svand_u8_x(svptrue_pat_b8(SV_VL32), q4bits, m4b));
+                        svint8_t q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
+                        sumi1 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi1, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+0]);
+
+                        q4bytes = svreinterpret_s8_u8(svlsr_n_u8_x(svptrue_pat_b8(SV_VL32), q4bits, 4));
+                        q8bytes = svld1_s8(svptrue_pat_b8(SV_VL32), q8); q8 += 32;
+                        sumi2 = svmla_n_s32_x(svptrue_pat_b32(SV_VL8), sumi2, svdot_s32(mzero, q4bytes, q8bytes), scales[2*j+1]);
+                    }
+                    sumf += d * (svaddv_s32(svptrue_pat_b32(SV_VL8), svadd_s32_x(svptrue_pat_b32(SV_VL8), sumi1, sumi2)));
+                } break;
+            default:
+                assert(false && "Unsupported vector length");
+                break;
+        }
+    }
+    *s = sumf;
+#elif defined __ARM_NEON
     const uint8x16_t m4b = vdupq_n_u8(0xf);
     const int32x4_t mzero = vdupq_n_s32(0);
 
@@ -5636,6 +6426,107 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     *s = sumf;
 
+#elif defined __wasm_simd128__
+    const uint8_t * scales = (const uint8_t*)&utmp[0];
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); // Corrected sign
+
+        const uint8_t * restrict q4 = x[i].qs;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        // Process scales and mins
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        // Sum mins * q8sums
+        int32_t sumi = 0;
+        const int16_t * restrict q8sums = y[i].bsums;
+        const uint8_t * m = (const uint8_t *)&utmp[2];
+        for (int j = 0; j < 16; j += 2) {
+            sumi += (q8sums[j] + q8sums[j+1]) * m[j/2];
+        }
+        sumf -= dmin * sumi;
+
+        int32_t sumi1 = 0;
+        int32_t sumi2 = 0;
+
+        for (int j = 0; j < QK_K/64; ++j) {
+            // Load 64 4-bit weights (32 bytes)
+            const v128_t q4x0 = wasm_v128_load(q4);
+            const v128_t q4x1 = wasm_v128_load(q4 + 16);
+            q4 += 32;
+
+            // Split into low/high nibbles
+            const v128_t q4l0 = wasm_v128_and(q4x0, wasm_i8x16_splat(0x0F));
+            const v128_t q4h0 = wasm_u8x16_shr(q4x0, 4);
+            const v128_t q4l1 = wasm_v128_and(q4x1, wasm_i8x16_splat(0x0F));
+            const v128_t q4h1 = wasm_u8x16_shr(q4x1, 4);
+
+            // Load 64 8-bit values (64 bytes)
+            const v128_t q8x0 = wasm_v128_load(q8);
+            const v128_t q8x1 = wasm_v128_load(q8 + 16);
+            const v128_t q8x2 = wasm_v128_load(q8 + 32);
+            const v128_t q8x3 = wasm_v128_load(q8 + 48);
+            q8 += 64;
+
+            // Low nibble products
+            v128_t vacc1 = wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_low_i8x16(q4l0),
+                wasm_i16x8_extend_low_i8x16(q8x0)
+            );
+            vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_high_i8x16(q4l0),
+                wasm_i16x8_extend_high_i8x16(q8x0)
+            ));
+            vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_low_i8x16(q4l1),
+                wasm_i16x8_extend_low_i8x16(q8x1)
+            ));
+            vacc1 = wasm_i32x4_add(vacc1, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_high_i8x16(q4l1),
+                wasm_i16x8_extend_high_i8x16(q8x1)
+            ));
+
+            // High nibble products
+            v128_t vacc2 = wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_low_i8x16(q4h0),
+                wasm_i16x8_extend_low_i8x16(q8x2)
+            );
+            vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_high_i8x16(q4h0),
+                wasm_i16x8_extend_high_i8x16(q8x2)
+            ));
+            vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_low_i8x16(q4h1),
+                wasm_i16x8_extend_low_i8x16(q8x3)
+            ));
+            vacc2 = wasm_i32x4_add(vacc2, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_high_i8x16(q4h1),
+                wasm_i16x8_extend_high_i8x16(q8x3)
+            ));
+
+            // Accumulate scaled results
+            int32_t vacc1_sum = wasm_i32x4_extract_lane(vacc1, 0) + wasm_i32x4_extract_lane(vacc1, 1) +
+                                wasm_i32x4_extract_lane(vacc1, 2) + wasm_i32x4_extract_lane(vacc1, 3);
+            sumi1 += vacc1_sum * scales[2*j];
+
+            int32_t vacc2_sum = wasm_i32x4_extract_lane(vacc2, 0) + wasm_i32x4_extract_lane(vacc2, 1) +
+                                wasm_i32x4_extract_lane(vacc2, 2) + wasm_i32x4_extract_lane(vacc2, 3);
+            sumi2 += vacc2_sum * scales[2*j+1];
+        }
+
+        sumf += d * (sumi1 + sumi2);
+    }
+
+    *s = sumf;
+
 #elif defined __AVX2__
 
     const __m256i m4 = _mm256_set1_epi8(0xF);
@@ -5993,11 +6884,6 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
     *s = vec_extract(vsumf0, 0);
 
 #elif defined __loongarch_asx
-    GGML_UNUSED(kmask1);
-    GGML_UNUSED(kmask2);
-    GGML_UNUSED(kmask3);
-
-    const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
 
     __m256 acc = (__m256)__lasx_xvldi(0);
     __m128 acc_m = (__m128)__lsx_vldi(0);
@@ -6017,33 +6903,34 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         const uint8_t * restrict q4 = x[i].qs;
         const int8_t  * restrict q8 = y[i].qs;
 
-        const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]));
+        const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]);
+        const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128);
+        const __m128i scales128 = __lsx_vsllwil_h_b(mins_and_scales128, 0);
 
         const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);
         const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));
-        const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s);
+        const __m128i prod = lsx_madd_h(mins128, q8s);
         acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m);
 
-        const __m128i sc128  = lasx_extracti128(mins_and_scales, 0);
-        const __m256i scales = lasx_insertf128(sc128, sc128);
+        const __m256i scales = lasx_insertf128(scales128, scales128);
 
         __m256i sumi = __lasx_xvldi(0);
 
         for (int j = 0; j < QK_K/64; ++j) {
 
-            const __m256i scale_l = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0));
-            const __m256i scale_h = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1));
+            const __m256i scale_l = lasx_xvrepl128vei_h(scales, 2 * j + 0);
+            const __m256i scale_h = lasx_xvrepl128vei_h(scales, 2 * j + 1);
 
             const __m256i q4bits = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
-            const __m256i q4l = __lasx_xvand_v(q4bits, m4);
-            const __m256i q4h = __lasx_xvand_v(__lasx_xvsrli_h(q4bits, 4), m4);
+            const __m256i q4l = __lasx_xvandi_b(q4bits, 0xf);
+            const __m256i q4h = __lasx_xvsrli_b(q4bits, 4);
 
             const __m256i q8l = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
-            __m256i p16l = lasx_maddubs_h(q4l, q8l);
+            __m256i p16l = lasx_madd_h_b(q4l, q8l);
             p16l = lasx_madd_h(scale_l, p16l);
 
             const __m256i q8h = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
-            __m256i p16h = lasx_maddubs_h(q4h, q8h);
+            __m256i p16h = lasx_madd_h_b(q4h, q8h);
             p16h = lasx_madd_h(scale_h, p16h);
             const __m256i sumj = __lasx_xvadd_w(p16l, p16h);
 
@@ -6060,9 +6947,78 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
     acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1);
 
 
-    ft_union fi;
-    fi.i = __lsx_vpickve2gr_w(acc_m, 0);
-    *s = hsum_float_8(acc) + fi.f ;
+    *s = hsum_float_8(acc) + ((v4f32)acc_m)[0];
+#elif defined(__VXE__) || defined(__VXE2__)
+    const uint8x16_t v_lm = vec_splat_u8(0x0F);
+    const int32x4_t v_z = vec_splat_s32(0);
+
+    uint8x16_t v_x[2];
+    int8x16_t  v_xl[2];
+    int8x16_t  v_y[2];
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums);
+        const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums);
+        const int16x8_t v_ysums = vec_padd_s16(v_ysumsl, v_ysumsh);
+
+        memcpy(utmp, x[i].scales, 12);
+
+        uint32x4_t v_mins8 = { 0 };
+        v_mins8 = vec_insert(utmp[1] & kmask1, v_mins8, 0);
+        v_mins8 = vec_insert(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), v_mins8, 1);
+
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[0] &= kmask1;
+
+        const int16x8_t v_minsh = (int16x8_t)vec_unpackh((uint8x16_t)v_mins8);
+
+        const int32x4_t v_minso = vec_mulo(v_ysums, v_minsh);
+        const int32x4_t v_minse = vec_mule(v_ysums, v_minsh);
+        const int32x4_t v_mins = v_minso + v_minse;
+        sumf -= dmin * (v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3]);
+
+        const uint8_t * scales = (const uint8_t *)utmp;
+        const uint8_t * restrict x0 = x[i].qs;
+        const int8_t  * restrict y0 = y[i].qs;
+
+        int32_t sumi1 = 0;
+        int32_t sumi2 = 0;
+
+        for (int j = 0; j < QK_K/64; ++j) {
+            v_x[0] = vec_xl(0 , x0);
+            v_x[1] = vec_xl(16, x0);
+            x0 += 32;
+
+            v_y[0] = vec_xl(0 , y0);
+            v_y[1] = vec_xl(16, y0);
+            y0 += 32;
+
+            v_xl[0] = (int8x16_t)vec_and(v_x[0], v_lm);
+            v_xl[1] = (int8x16_t)vec_and(v_x[1], v_lm);
+
+            const int32x4_t p1 = ggml_vec_dot(ggml_vec_dot(v_z, v_xl[0], v_y[0]), v_xl[1], v_y[1]);
+            sumi1 += (p1[0] + p1[1] + p1[2] + p1[3]) * scales[2*j+0];
+
+            v_y[0] = vec_xl(0 , y0);
+            v_y[1] = vec_xl(16, y0);
+            y0 += 32;
+
+            v_xl[0] = (int8x16_t)vec_sr(v_x[0], 4);
+            v_xl[1] = (int8x16_t)vec_sr(v_x[1], 4);
+
+            const int32x4_t p2 = ggml_vec_dot(ggml_vec_dot(v_z, v_xl[0], v_y[0]), v_xl[1], v_y[1]);
+            sumi2 += (p2[0] + p2[1] + p2[2] + p2[3]) * scales[2*j+1];
+        }
+
+        sumf += d * (sumi1 + sumi2);
+    }
+
+    *s = sumf;
 #else
 
     const uint8_t * scales = (const uint8_t*)&utmp[0];
@@ -6388,6 +7344,118 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     *s = hsum_float_8(acc) + summs;
 
+#elif defined __wasm_simd128__
+    //const uint8_t * scales = (const uint8_t*)&utmp[0];
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin); // Fixed sign
+
+        const uint8_t * restrict q5 = x[i].qs;
+        const uint8_t * restrict qh = x[i].qh;
+        const int8_t  * restrict q8 = y[i].qs;
+
+        // Process scales and mins
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        // Sum mins * q8sums
+        int32_t sumi_mins = 0;
+        const int16_t * restrict q8sums = y[i].bsums;
+        const uint8_t * m = (const uint8_t *)&utmp[2];
+        for (int j = 0; j < 16; j += 2) {
+            sumi_mins += (q8sums[j] + q8sums[j+1]) * m[j/2];
+        }
+        sumf -= dmin * sumi_mins; // Correct subtraction
+
+        v128_t qh0 = wasm_v128_load(qh);
+        v128_t qh1 = wasm_v128_load(qh + 16);
+        const uint8_t * sc = (const uint8_t *)utmp;
+
+        int32_t sumi = 0;
+
+        for (int j = 0; j < QK_K/64; ++j) {
+            const int shift = j * 2;
+            v128_t qh_shift0 = wasm_u8x16_shr(qh0, shift);
+            v128_t qh_shift1 = wasm_u8x16_shr(qh1, shift);
+
+            v128_t qh_low0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x01)), 4);
+            v128_t qh_high0 = wasm_i8x16_shl(wasm_v128_and(qh_shift0, wasm_i8x16_splat(0x02)), 3);
+            v128_t qh_low1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x01)), 4);
+            v128_t qh_high1 = wasm_i8x16_shl(wasm_v128_and(qh_shift1, wasm_i8x16_splat(0x02)), 3);
+
+            v128_t q5_0 = wasm_v128_load(q5);
+            v128_t q5_1 = wasm_v128_load(q5 + 16);
+            q5 += 32;
+
+            v128_t q5l_0 = wasm_v128_or(wasm_v128_and(q5_0, wasm_i8x16_splat(0x0F)), qh_low0);
+            v128_t q5h_0 = wasm_v128_or(wasm_u8x16_shr(q5_0, 4), qh_high0);
+            v128_t q5l_1 = wasm_v128_or(wasm_v128_and(q5_1, wasm_i8x16_splat(0x0F)), qh_low1);
+            v128_t q5h_1 = wasm_v128_or(wasm_u8x16_shr(q5_1, 4), qh_high1);
+
+            v128_t q8_0 = wasm_v128_load(q8);
+            v128_t q8_1 = wasm_v128_load(q8 + 16);
+            v128_t q8_2 = wasm_v128_load(q8 + 32);
+            v128_t q8_3 = wasm_v128_load(q8 + 48);
+            q8 += 64;
+
+            // Process low quants
+            v128_t pl0 = wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_low_i8x16(q5l_0),
+                wasm_i16x8_extend_low_i8x16(q8_0)
+            );
+            pl0 = wasm_i32x4_add(pl0, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_high_i8x16(q5l_0),
+                wasm_i16x8_extend_high_i8x16(q8_0)
+            ));
+            v128_t pl1 = wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_low_i8x16(q5l_1),
+                wasm_i16x8_extend_low_i8x16(q8_1)
+            );
+            pl1 = wasm_i32x4_add(pl1, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_high_i8x16(q5l_1),
+                wasm_i16x8_extend_high_i8x16(q8_1)
+            ));
+            v128_t sum_low = wasm_i32x4_add(pl0, pl1);
+
+            // Process high quants
+            v128_t ph0 = wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_low_i8x16(q5h_0),
+                wasm_i16x8_extend_low_i8x16(q8_2)
+            );
+            ph0 = wasm_i32x4_add(ph0, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_high_i8x16(q5h_0),
+                wasm_i16x8_extend_high_i8x16(q8_2)
+            ));
+            v128_t ph1 = wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_low_i8x16(q5h_1),
+                wasm_i16x8_extend_low_i8x16(q8_3)
+            );
+            ph1 = wasm_i32x4_add(ph1, wasm_i32x4_dot_i16x8(
+                wasm_i16x8_extend_high_i8x16(q5h_1),
+                wasm_i16x8_extend_high_i8x16(q8_3)
+            ));
+            v128_t sum_high = wasm_i32x4_add(ph0, ph1);
+
+            // Accumulate with scale factors
+            int32_t sl = wasm_i32x4_extract_lane(sum_low, 0) + wasm_i32x4_extract_lane(sum_low, 1) +
+                        wasm_i32x4_extract_lane(sum_low, 2) + wasm_i32x4_extract_lane(sum_low, 3);
+            int32_t sh = wasm_i32x4_extract_lane(sum_high, 0) + wasm_i32x4_extract_lane(sum_high, 1) +
+                        wasm_i32x4_extract_lane(sum_high, 2) + wasm_i32x4_extract_lane(sum_high, 3);
+
+            sumi += sl * sc[2*j] + sh * sc[2*j+1];
+        }
+
+        sumf += d * sumi;
+    }
+
+    *s = sumf;
+
 #elif defined __riscv_v_intrinsic
 
     const uint8_t * scales = (const uint8_t*)&utmp[0];
@@ -6610,19 +7678,11 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
     *s = vec_extract(vsumf0, 0);
 
 #elif defined __loongarch_asx
-    GGML_UNUSED(kmask1);
-    GGML_UNUSED(kmask2);
-    GGML_UNUSED(kmask3);
-
-    const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
-    const __m128i mzero = __lsx_vldi(0);
-    const __m256i mone  = __lasx_xvreplgr2vr_b(1);
 
     __m256 acc = (__m256)__lasx_xvldi(0);
+    __m128 acc_m = (__m128)__lsx_vldi(0);
 
-    float summs = 0.f;
-
-   for (int i = 0; i < nb; ++i) {
+    for (int i = 0; i < nb; ++i) {
 
         const uint8_t * restrict q5 = x[i].qs;
         const int8_t  * restrict q8 = y[i].qs;
@@ -6637,49 +7697,40 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         utmp[2] = uaux;
         utmp[0] &= kmask1;
 
-        const __m256i mins_and_scales = lasx_extu8_16(lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]));
+        const __m128i mins_and_scales128 = lsx_set_w(utmp[3], utmp[2], utmp[1], utmp[0]);
+        const __m128i mins128 = __lsx_vexth_h_b(mins_and_scales128);
+        const __m128i scales128 = __lsx_vsllwil_h_b(mins_and_scales128, 0);
 
         const __m256i q8sums = __lasx_xvld((const __m256i*)y[i].bsums, 0);
         const __m128i q8s = lsx_hadd_h(lasx_extracti128(q8sums, 0), lasx_extracti128(q8sums, 1));
-        const __m128i prod = lsx_madd_h(lasx_extracti128(mins_and_scales, 1), q8s);
-        const __m128i hsum = lsx_hadd_w(lsx_hadd_w(prod, mzero), mzero);
-        summs += dmin * __lsx_vpickve2gr_w(hsum, 0);    //TODO check
+        const __m128i prod = lsx_madd_h(mins128, q8s);
+        acc_m = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(dmin), __lsx_vffint_s_w(prod), acc_m);
 
-        const __m128i sc128  = lasx_extracti128(mins_and_scales, 0);
-        const __m256i scales = lasx_insertf128(sc128, sc128);
+        const __m256i scales = lasx_insertf128(scales128, scales128);
 
         const __m256i hbits = __lasx_xvld((const __m256i*)x[i].qh, 0);
-        __m256i hmask = mone;
 
         __m256i sumi = __lasx_xvldi(0);
 
-        int bit = 0;
-        __m256i xvbit;
-
         for (int j = 0; j < QK_K/64; ++j) {
 
-            const __m256i scale_0 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+0));
-            const __m256i scale_1 = lasx_shuffle_b(scales, get_scale_shuffle_k4(2*j+1));
+            const __m256i scale_0 = lasx_xvrepl128vei_h(scales, 2 * j + 0);
+            const __m256i scale_1 = lasx_xvrepl128vei_h(scales, 2 * j + 1);
 
             const __m256i q5bits = __lasx_xvld((const __m256i*)q5, 0); q5 += 32;
 
-            xvbit = __lasx_xvreplgr2vr_h(bit++);
-            const __m256i q5l_0 = __lasx_xvand_v(q5bits, m4);
-            const __m256i q5h_0 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
-            const __m256i q5_0  = __lasx_xvadd_b(q5l_0, q5h_0);
-            hmask = __lasx_xvslli_h(hmask, 1);
-
-            xvbit = __lasx_xvreplgr2vr_h(bit++);
-            const __m256i q5l_1 = __lasx_xvand_v(__lasx_xvsrli_h(q5bits, 4), m4);
-            const __m256i q5h_1 = __lasx_xvslli_h(__lasx_xvsrl_h(__lasx_xvand_v(hbits, hmask), xvbit), 4);
-            const __m256i q5_1  = __lasx_xvadd_b(q5l_1, q5h_1);
-            hmask = __lasx_xvslli_h(hmask, 1);
+            const __m256i q5l_0 = __lasx_xvandi_b(q5bits, 0xf);
+            const __m256i q5l_1 = __lasx_xvsrli_b(q5bits, 4);
+            const __m256i q5h_0 = __lasx_xvnori_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 2 * j + 0), 0), 0xef);
+            const __m256i q5h_1 = __lasx_xvnori_b(__lasx_xvseqi_b(lasx_xvandi_b_bit(hbits, 2 * j + 1), 0), 0xef);
+            const __m256i q5_0  = __lasx_xvor_v(q5l_0, q5h_0);
+            const __m256i q5_1  = __lasx_xvor_v(q5l_1, q5h_1);
 
             const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
             const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
 
-            __m256i p16_0 = lasx_maddubs_h(q5_0, q8_0);
-            __m256i p16_1 = lasx_maddubs_h(q5_1, q8_1);
+            __m256i p16_0 = lasx_madd_h_b(q5_0, q8_0);
+            __m256i p16_1 = lasx_madd_h_b(q5_1, q8_1);
 
             p16_0 = lasx_madd_h(scale_0, p16_0);
             p16_1 = lasx_madd_h(scale_1, p16_1);
@@ -6693,8 +7744,98 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     }
 
-    *s = hsum_float_8(acc) + summs;
+    acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 8));
+    acc_m = __lsx_vfadd_s(acc_m, (__m128)__lsx_vbsrl_v(acc_m, 4));
 
+    *s = hsum_float_8(acc) + ((v4f32)acc_m)[0];
+#elif defined(__VXE__) || defined(__VXE2__)
+    const uint8x16_t v_lm = vec_splat_u8(0x0F);
+    const uint8x16_t v_1m = vec_splat_u8(0x01);
+    const uint8x16_t v_2m = vec_splat_u8(0x02);
+
+    const int32x4_t v_z = vec_splat_s32(0);
+
+    const uchar8x16_t v_minsm = {
+        0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F,
+        0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF
+    };
+
+    int8x16_t  q5b[4];
+    uint8x16_t q5h[4];
+
+    uint8x16_t v_xl[2];
+    uint8x16_t v_xh[2];
+    int8x16_t  v_y[4];
+
+    float sumf = 0;
+
+    for (int i = 0; i < nb; ++i) {
+        const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d);
+        const float dmin = y[i].d * GGML_FP16_TO_FP32(x[i].dmin);
+
+        const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums);
+        const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums);
+        const int16x8_t v_ysums = vec_padd_s16(v_ysumsl, v_ysumsh);
+
+        memcpy(utmp, x[i].scales, 12);
+        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+        const uint32_t uaux = utmp[1] & kmask1;
+        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+        utmp[2] = uaux;
+        utmp[0] &= kmask1;
+
+        const uint8x16_t v_mins16 = vec_xl(0, (const uint8_t *)utmp);
+        const uint8x16_t v_mins8 = vec_perm(v_mins16, v_mins16, v_minsm);
+        const int16x8_t v_minsh = (int16x8_t)vec_unpackh(v_mins8);
+
+        const int32x4_t v_minsho = vec_mulo(v_ysums, v_minsh);
+        const int32x4_t v_minshe = vec_mule(v_ysums, v_minsh);
+        const int32x4_t v_mins = vec_add(v_minsho, v_minshe);
+        const int32_t mins = v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3];
+
+        const uint8_t * scales = (const uint8_t *)utmp;
+        const uint8_t * restrict x0l = x[i].qs;
+        const uint8_t * restrict x0h = x[i].qh;
+        const int8_t  * restrict y0 = y[i].qs;
+
+        v_xh[0] = vec_xl(0 , x0h);
+        v_xh[1] = vec_xl(16, x0h);
+
+        int32_t sumi = 0;
+        for (int j = 0; j < QK_K/64; ++j) {
+            v_xl[0] = vec_xl(0 , x0l);
+            v_xl[1] = vec_xl(16, x0l);
+            x0l += 32;
+
+            v_y[0] = vec_xl(0 , y0);
+            v_y[1] = vec_xl(16, y0);
+            v_y[2] = vec_xl(32, y0);
+            v_y[3] = vec_xl(48, y0);
+            y0 += 64;
+
+            q5h[0] = vec_sl(vec_and(v_1m, v_xh[0]), 4);
+            q5h[1] = vec_sl(vec_and(v_1m, v_xh[1]), 4);
+            q5h[2] = vec_sl(vec_and(v_2m, v_xh[0]), 3);
+            q5h[3] = vec_sl(vec_and(v_2m, v_xh[1]), 3);
+            v_xh[0] = vec_sr(v_xh[0], 2);
+            v_xh[1] = vec_sr(v_xh[1], 2);
+
+            q5b[0] = (int8x16_t)vec_or(vec_and(v_xl[0], v_lm), q5h[0]);
+            q5b[1] = (int8x16_t)vec_or(vec_and(v_xl[1], v_lm), q5h[1]);
+            q5b[2] = (int8x16_t)vec_or(vec_sr(v_xl[0], 4), q5h[2]);
+            q5b[3] = (int8x16_t)vec_or(vec_sr(v_xl[1], 4), q5h[3]);
+
+            int32x4_t sumi0 = ggml_vec_dot(ggml_vec_dot(v_z, q5b[0], v_y[0]), q5b[1], v_y[1]);
+            int32x4_t sumi1 = ggml_vec_dot(ggml_vec_dot(v_z, q5b[2], v_y[2]), q5b[3], v_y[3]);
+
+            sumi += (sumi0[0] + sumi0[1] + sumi0[2] + sumi0[3]) * *scales++;
+            sumi += (sumi1[0] + sumi1[1] + sumi1[2] + sumi1[3]) * *scales++;
+        }
+
+        sumf += d * sumi - dmin * mins;
+    }
+
+    *s = sumf;
 #else
 
     const uint8_t * scales = (const uint8_t*)&utmp[0];
@@ -7051,6 +8192,85 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
     *s = hsum_float_8(acc);
 
+#elif defined __wasm_simd128__
+    int8_t aux8[QK_K] __attribute__((aligned(16)));
+    int32_t aux32[8] __attribute__((aligned(16))) = {0};
+    float sums[8] __attribute__((aligned(16))) = {0};
+
+    for (int i = 0; i < nb; ++i) {
+        // Unpack 6-bit quantized data into aux8 (unchanged)
+        const uint8_t * restrict q4 = x[i].ql;
+        const uint8_t * restrict qh = x[i].qh;
+        int8_t * a = aux8;
+        for (int j = 0; j < QK_K; j += 128) {
+            for (int l = 0; l < 32; ++l) {
+                a[l +  0] = (int8_t)((q4[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
+                a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
+                a[l + 64] = (int8_t)((q4[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
+                a[l + 96] = (int8_t)((q4[l + 32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
+            }
+            a += 128;
+            q4 += 64;
+            qh += 32;
+        }
+
+        const int8_t * restrict a_ptr = aux8;
+        const int8_t * restrict q8 = y[i].qs;
+        v128_t acc0 = wasm_i32x4_splat(0);
+        v128_t acc1 = wasm_i32x4_splat(0);
+
+        for (int j = 0; j < QK_K/16; ++j) {
+            const int scale = x[i].scales[j];
+            const v128_t vscale = wasm_i32x4_splat(scale);
+
+            // Load 16 elements from a and q8
+            const v128_t a_vec = wasm_v128_load(a_ptr);
+            const v128_t q8_vec = wasm_v128_load(q8);
+
+            // Process low 8 elements
+            v128_t a_low = wasm_i16x8_extend_low_i8x16(a_vec);
+            v128_t q8_low = wasm_i16x8_extend_low_i8x16(q8_vec);
+            v128_t prod_low = wasm_i16x8_mul(a_low, q8_low);
+            v128_t prod_lo_lo = wasm_i32x4_extend_low_i16x8(prod_low);
+            v128_t prod_lo_hi = wasm_i32x4_extend_high_i16x8(prod_low);
+
+            // Process high 8 elements
+            v128_t a_high = wasm_i16x8_extend_high_i8x16(a_vec);
+            v128_t q8_high = wasm_i16x8_extend_high_i8x16(q8_vec);
+            v128_t prod_high = wasm_i16x8_mul(a_high, q8_high);
+            v128_t prod_hi_lo = wasm_i32x4_extend_low_i16x8(prod_high);
+            v128_t prod_hi_hi = wasm_i32x4_extend_high_i16x8(prod_high);
+
+            // Scale and accumulate
+            prod_lo_lo = wasm_i32x4_mul(prod_lo_lo, vscale);
+            prod_lo_hi = wasm_i32x4_mul(prod_lo_hi, vscale);
+            prod_hi_lo = wasm_i32x4_mul(prod_hi_lo, vscale);
+            prod_hi_hi = wasm_i32x4_mul(prod_hi_hi, vscale);
+
+            acc0 = wasm_i32x4_add(acc0, wasm_i32x4_add(prod_lo_lo, prod_hi_lo));
+            acc1 = wasm_i32x4_add(acc1, wasm_i32x4_add(prod_lo_hi, prod_hi_hi));
+
+            a_ptr += 16;
+            q8 += 16;
+        }
+
+        // Store accumulated results
+        wasm_v128_store(&aux32[0], acc0);
+        wasm_v128_store(&aux32[4], acc1);
+
+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+        for (int l = 0; l < 8; ++l) {
+            sums[l] += d * aux32[l];
+        }
+    }
+
+    // Sum final results
+    float sumf = 0;
+    for (int l = 0; l < 8; ++l) {
+        sumf += sums[l];
+    }
+    *s = sumf;
+
 #elif defined __riscv_v_intrinsic
 
     float sumf = 0;
@@ -7275,8 +8495,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
 
 #elif defined __loongarch_asx
 
-    const __m256i m4 = __lasx_xvreplgr2vr_b(0xF);
-    const __m256i m2 = __lasx_xvreplgr2vr_b(3);
     const __m256i m32s = __lasx_xvreplgr2vr_b(32);
 
     __m256 acc = (__m256)__lasx_xvldi(0);
@@ -7289,58 +8507,42 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
         const uint8_t * restrict qh = x[i].qh;
         const int8_t  * restrict q8 = y[i].qs;
 
-        const __m128i scales = __lsx_vld((const __m128i*)x[i].scales, 0);
+        const __m128i scales128 = __lsx_vld((const __m128i*)x[i].scales, 0);
+        const v16i8 shuffle_mask = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15};
+        const __m256i scales_shuffled = lasx_ext8_16(__lsx_vshuf_b(scales128, scales128, (__m128i)shuffle_mask));
 
         __m256i sumi = __lasx_xvldi(0);
 
-        int is = 0;
-
         for (int j = 0; j < QK_K/128; ++j) {
 
-            const __m128i scale_0 = lsx_shuffle_b(scales, get_scale_shuffle(is + 0));
-            const __m128i scale_1 = lsx_shuffle_b(scales, get_scale_shuffle(is + 1));
-            const __m128i scale_2 = lsx_shuffle_b(scales, get_scale_shuffle(is + 2));
-            const __m128i scale_3 = lsx_shuffle_b(scales, get_scale_shuffle(is + 3));
-            is += 4;
-
             const __m256i q4bits1 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
             const __m256i q4bits2 = __lasx_xvld((const __m256i*)q4, 0); q4 += 32;
             const __m256i q4bitsH = __lasx_xvld((const __m256i*)qh, 0); qh += 32;
 
-            const __m256i q4h_0 = __lasx_xvslli_h(__lasx_xvand_v(q4bitsH, m2), 4);
-            const __m256i q4h_1 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 2), m2), 4);
-            const __m256i q4h_2 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 4), m2), 4);
-            const __m256i q4h_3 = __lasx_xvslli_h(__lasx_xvand_v(__lasx_xvsrli_h(q4bitsH, 6), m2), 4);
+            const __m256i q4h_0 = __lasx_xvslli_b(__lasx_xvandi_b(q4bitsH, 3), 4);
+            const __m256i q4h_1 = __lasx_xvslli_b(__lasx_xvandi_b(q4bitsH, 3 << 2), 2);
+            const __m256i q4h_2 = __lasx_xvandi_b(q4bitsH, 3 << 4);
+            const __m256i q4h_3 = __lasx_xvsrli_b(__lasx_xvandi_b(q4bitsH, 3 << 6), 2);
 
-            const __m256i q4_0 = __lasx_xvor_v(__lasx_xvand_v(q4bits1, m4), q4h_0);
-            const __m256i q4_1 = __lasx_xvor_v(__lasx_xvand_v(q4bits2, m4), q4h_1);
-            const __m256i q4_2 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits1, 4), m4), q4h_2);
-            const __m256i q4_3 = __lasx_xvor_v(__lasx_xvand_v(__lasx_xvsrli_h(q4bits2, 4), m4), q4h_3);
+            const __m256i q4_0 = __lasx_xvor_v(__lasx_xvandi_b(q4bits1, 0xf), q4h_0);
+            const __m256i q4_1 = __lasx_xvor_v(__lasx_xvandi_b(q4bits2, 0xf), q4h_1);
+            const __m256i q4_2 = __lasx_xvor_v(__lasx_xvsrli_b(q4bits1, 4), q4h_2);
+            const __m256i q4_3 = __lasx_xvor_v(__lasx_xvsrli_b(q4bits2, 4), q4h_3);
 
             const __m256i q8_0 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
             const __m256i q8_1 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
             const __m256i q8_2 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
             const __m256i q8_3 = __lasx_xvld((const __m256i*)q8, 0); q8 += 32;
 
-            __m256i q8s_0 = lasx_maddubs_h(m32s, q8_0);
-            __m256i q8s_1 = lasx_maddubs_h(m32s, q8_1);
-            __m256i q8s_2 = lasx_maddubs_h(m32s, q8_2);
-            __m256i q8s_3 = lasx_maddubs_h(m32s, q8_3);
+            __m256i p16_0 = lasx_madd_h_b(__lasx_xvsub_b(q4_0, m32s), q8_0);
+            __m256i p16_1 = lasx_madd_h_b(__lasx_xvsub_b(q4_1, m32s), q8_1);
+            __m256i p16_2 = lasx_madd_h_b(__lasx_xvsub_b(q4_2, m32s), q8_2);
+            __m256i p16_3 = lasx_madd_h_b(__lasx_xvsub_b(q4_3, m32s), q8_3);
 
-            __m256i p16_0 = lasx_maddubs_h(q4_0, q8_0);
-            __m256i p16_1 = lasx_maddubs_h(q4_1, q8_1);
-            __m256i p16_2 = lasx_maddubs_h(q4_2, q8_2);
-            __m256i p16_3 = lasx_maddubs_h(q4_3, q8_3);
-
-            p16_0 = __lasx_xvsub_h(p16_0, q8s_0);
-            p16_1 = __lasx_xvsub_h(p16_1, q8s_1);
-            p16_2 = __lasx_xvsub_h(p16_2, q8s_2);
-            p16_3 = __lasx_xvsub_h(p16_3, q8s_3);
-
-            p16_0 = lasx_madd_h(lasx_ext8_16(scale_0), p16_0);
-            p16_1 = lasx_madd_h(lasx_ext8_16(scale_1), p16_1);
-            p16_2 = lasx_madd_h(lasx_ext8_16(scale_2), p16_2);
-            p16_3 = lasx_madd_h(lasx_ext8_16(scale_3), p16_3);
+            p16_0 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 0), p16_0);
+            p16_1 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 1), p16_1);
+            p16_2 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 2), p16_2);
+            p16_3 = lasx_madd_h(lasx_xvrepl128vei_h(scales_shuffled, 4 * j + 3), p16_3);
 
             sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_0, p16_1));
             sumi = __lasx_xvadd_w(sumi, __lasx_xvadd_w(p16_2, p16_3));
@@ -7350,7 +8552,130 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, size_t bs, const void * r
     }
 
     *s = hsum_float_8(acc);
+#elif defined(__VXE__) || defined(__VXE2__)
+    float sum = 0;
 
+    // Lower 4-bit and upper 2-bit masks
+    const uint8x16_t v_lm = vec_splat_u8(0x0F);
+    const uint8x16_t v_um = vec_splat_u8(0x03);
+
+    const int32x4_t v_z = vec_splat_s32(0);
+
+    int8x16_t  q6b[4];
+    uint8x16_t q6h[4];
+
+    uint8x16_t v_xl[4];
+    uint8x16_t v_xh[2];
+    int8x16_t  v_y[4];
+
+    for (int i = 0; i < nb; ++i) {
+        const float d_all = GGML_FP16_TO_FP32(x[i].d);
+
+        const uint8_t * restrict x0l = x[i].ql;
+        const uint8_t * restrict x0h = x[i].qh;
+        const int8_t  * restrict y0 = y[i].qs;
+
+        const int8_t  * restrict scale = x[i].scales;
+
+        const int16x8_t v_ysumsl = vec_xl(0 , y[i].bsums);
+        const int16x8_t v_ysumsh = vec_xl(16, y[i].bsums);
+
+        const int8x16_t v_scale  = vec_xl(0, scale);
+        const int16x8_t v_scalel = vec_unpackh(v_scale);
+        const int16x8_t v_scaleh = vec_unpackl(v_scale);
+
+        const int32x4_t v_minslo = vec_mulo(v_ysumsl, v_scalel);
+        const int32x4_t v_minsle = vec_mule(v_ysumsl, v_scalel);
+        const int32x4_t v_minsho = vec_mulo(v_ysumsh, v_scaleh);
+        const int32x4_t v_minshe = vec_mule(v_ysumsh, v_scaleh);
+        const int32x4_t v_mins = v_minslo + v_minsle + v_minsho + v_minshe;
+
+        const int32_t mins = v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3];
+
+        int32_t isum = 0;
+        for (int j = 0; j < QK_K/128; ++j) {
+            // Load model upper 2 bits
+            v_xh[0] = vec_xl(0 , x0h);
+            v_xh[1] = vec_xl(16, x0h);
+            x0h += 32;
+
+            // Load model lower 4 bits
+            v_xl[0] = vec_xl(0 , x0l);
+            v_xl[1] = vec_xl(16, x0l);
+            v_xl[2] = vec_xl(32, x0l);
+            v_xl[3] = vec_xl(48, x0l);
+            x0l += 64;
+
+            // Load activation quants
+            v_y[0] = vec_xl(0 , y0);
+            v_y[1] = vec_xl(16, y0);
+            v_y[2] = vec_xl(32, y0);
+            v_y[3] = vec_xl(48, y0);
+            y0 += 64;
+
+            q6h[0] = vec_sl(vec_and(v_um, v_xh[0]), 4);
+            q6h[1] = vec_sl(vec_and(v_um, v_xh[1]), 4);
+            uint8x16_t shifted = vec_sr(v_xh[0], 2);
+            q6h[2] = vec_sl(vec_and(v_um, shifted), 4);
+            shifted = vec_sr(v_xh[1], 2);
+            q6h[3] = vec_sl(vec_and(v_um, shifted), 4);
+
+            q6b[0] = (int8x16_t)(vec_or(vec_and(v_xl[0], v_lm), q6h[0]));
+            q6b[1] = (int8x16_t)(vec_or(vec_and(v_xl[1], v_lm), q6h[1]));
+            q6b[2] = (int8x16_t)(vec_or(vec_and(v_xl[2], v_lm), q6h[2]));
+            q6b[3] = (int8x16_t)(vec_or(vec_and(v_xl[3], v_lm), q6h[3]));
+
+            int32x4_t summs0 = ggml_vec_dot(v_z, q6b[0], v_y[0]);
+            int32x4_t summs1 = ggml_vec_dot(v_z, q6b[1], v_y[1]);
+            int32x4_t summs2 = ggml_vec_dot(v_z, q6b[2], v_y[2]);
+            int32x4_t summs3 = ggml_vec_dot(v_z, q6b[3], v_y[3]);
+
+            isum += (summs0[0] + summs0[1] + summs0[2] + summs0[3]) * scale[0] +
+                    (summs1[0] + summs1[1] + summs1[2] + summs1[3]) * scale[1] +
+                    (summs2[0] + summs2[1] + summs2[2] + summs2[3]) * scale[2] +
+                    (summs3[0] + summs3[1] + summs3[2] + summs3[3]) * scale[3];
+
+            scale += 4;
+
+
+            // Load activation quants
+            v_y[0] = vec_xl(0 , y0);
+            v_y[1] = vec_xl(16, y0);
+            v_y[2] = vec_xl(32, y0);
+            v_y[3] = vec_xl(48, y0);
+            y0 += 64;
+
+            shifted = vec_sr(v_xh[0], 4);
+            q6h[0] = vec_sl(vec_and(v_um, shifted), 4);
+            shifted = vec_sr(v_xh[1], 4);
+            q6h[1] = vec_sl(vec_and(v_um, shifted), 4);
+            shifted = vec_sr(v_xh[0], 6);
+            q6h[2] = vec_sl(vec_and(v_um, shifted), 4);
+            shifted = vec_sr(v_xh[1], 6);
+            q6h[3] = vec_sl(vec_and(v_um, shifted), 4);
+
+            q6b[0] = (int8x16_t)(vec_or(vec_sr(v_xl[0], 4), q6h[0]));
+            q6b[1] = (int8x16_t)(vec_or(vec_sr(v_xl[1], 4), q6h[1]));
+            q6b[2] = (int8x16_t)(vec_or(vec_sr(v_xl[2], 4), q6h[2]));
+            q6b[3] = (int8x16_t)(vec_or(vec_sr(v_xl[3], 4), q6h[3]));
+
+            summs0 = ggml_vec_dot(v_z, q6b[0], v_y[0]);
+            summs1 = ggml_vec_dot(v_z, q6b[1], v_y[1]);
+            summs2 = ggml_vec_dot(v_z, q6b[2], v_y[2]);
+            summs3 = ggml_vec_dot(v_z, q6b[3], v_y[3]);
+
+            isum += (summs0[0] + summs0[1] + summs0[2] + summs0[3]) * scale[0] +
+                    (summs1[0] + summs1[1] + summs1[2] + summs1[3]) * scale[1] +
+                    (summs2[0] + summs2[1] + summs2[2] + summs2[3]) * scale[2] +
+                    (summs3[0] + summs3[1] + summs3[2] + summs3[3]) * scale[3];
+
+            scale += 4;
+        }
+
+        sum += d_all * y[i].d * (isum - 32 * mins);
+    }
+
+    *s = sum;
 #else
 
     int8_t  aux8[QK_K];
@@ -7711,7 +9036,57 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, size_t bs, const void
     }
 
     *s = 0.125f * hsum_float_8(accumf);
-
+//#elif defined(__VXE__) || defined(__VXE2__)
+//    const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
+//
+//    uint32_t aux32[4];
+//    const uint8_t * aux8 = (const uint8_t *)aux32;
+//
+//    float sumf = 0;
+//
+//    for (int i = 0; i < nb; ++i) {
+//        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+//        const uint16_t * restrict q2 = x[i].qs;
+//        const int8_t   * restrict q8 = y[i].qs;
+//
+//        float sumf1 = 0, sumf2 = 0;
+//
+//        for (int ib32 = 0; ib32 < QK_K/32; ib += 2) {
+//            int8x16_t q8b0 = vec_xl( 0, q8);
+//            int8x16_t qb81 = vec_xl(16, q8);
+//            int8x16_t q8b2 = vec_xl(32, q8);
+//            int8x16_t q8b3 = vec_xl(48, q8);
+//            q8 += 64;
+//
+//            memcpy(aux32, q2, 4 * sizeof(uint32_t));
+//            q2 += 8;
+//
+//            int8x16_t q2u0 = { *(const int64_t *)(iq2xxs_grid + aux8[ 0]), *(const int64_t *)(iq2xxs_grid + aux8[ 1]) };
+//            int8x16_t q2u1 = { *(const int64_t *)(iq2xxs_grid + aux8[ 2]), *(const int64_t *)(iq2xxs_grid + aux8[ 3]) };
+//            int8x16_t q2u2 = { *(const int64_t *)(iq2xxs_grid + aux8[ 8]), *(const int64_t *)(iq2xxs_grid + aux8[ 9]) };
+//            int8x16_t q2u3 = { *(const int64_t *)(iq2xxs_grid + aux8[10]), *(const int64_t *)(iq2xxs_grid + aux8[11]) };
+//
+//            int8x16_t q2s0 = { *(const int64_t *)(signs64 + ((aux32[1] >>  0) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >>  7) & 127)) };
+//            int8x16_t q2s1 = { *(const int64_t *)(signs64 + ((aux32[1] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[1] >> 21) & 127)) };
+//            int8x16_t q2s2 = { *(const int64_t *)(signs64 + ((aux32[3] >>  0) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >>  7) & 127)) };
+//            int8x16_t q2s3 = { *(const int64_t *)(signs64 + ((aux32[3] >> 14) & 127)), *(const int64_t *)(signs64 + ((aux32[3] >> 21) & 127)) };
+//
+//            q2u0 = vec_mul(q2u0, q2s0);
+//            q2u1 = vec_mul(q2u1, q2s1);
+//            q2u2 = vec_mul(q2u2, q2s2);
+//            q2u3 = vec_mul(q2u3, q2s3);
+//
+//            const int32x4_t p1 = ggml_vec_dot(ggml_vec_dot(vec_splat_s32(0), q2u0, q8b0), q2u1, q8b1);
+//            const int32x4_t p2 = ggml_vec_dot(ggml_vec_dot(vec_splat_s32(0), q2u2, q8b2), q2u3, q8b3);
+//
+//            sumf1 += (p1[0] + p1[1] + p1[2] + p1[3]) * (0.5f + (aux32[1] >> 28));
+//            sumf2 += (p2[0] + p2[1] + p2[2] + p2[3]) * (0.5f + (aux32[3] >> 28));
+//        }
+//
+//        sumf += d * (sumf1 + sumf2);
+//    }
+//
+//    *s = 0.25f * sumf;
 #else
 
     uint32_t aux32[2];
@@ -9665,13 +11040,9 @@ static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
 }
 #elif defined(__loongarch_asx)
 static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
-    const __m256i ax = __lasx_xvsigncov_b(x, x);
-    const __m256i sy = __lasx_xvsigncov_b(x, y);
-    __m256i tmp1, tmp2, tmp3;
-    tmp1 = __lasx_xvmulwev_h_bu_b(ax, sy);
-    tmp2 = __lasx_xvmulwod_h_bu_b(ax, sy);
-    tmp3 = __lasx_xvadd_h(tmp1, tmp2);
-    return __lasx_xvsat_h(tmp3, 15);
+    const __m256i a = __lasx_xvmulwev_h_b(x, y);
+    const __m256i b = __lasx_xvmulwod_h_b(x, y);
+    return __lasx_xvadd_h(a, b);
 }
 #endif
 
@@ -10476,6 +11847,27 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void *
 
     sumf = hsum_float_8(__lasx_xvfadd_s(accum1, accum2));
 
+#elif defined(__VXE__) || defined(__VXE2__)
+    const int8x16_t v_k = vec_xl(0, kvalues_iq4nl);
+    const uint8x16_t v_m = vec_splat_u8(0x0F);
+
+    for (; ib < nb; ++ib) {
+        const block_iq4_nl * restrict x0 = &x[ib];
+        const block_q8_0   * restrict y0 = &y[ib];
+
+        const uint8x16_t v_x = vec_xl(0, x0->qs);
+        int8x16_t v_xl = (int8x16_t)vec_and(v_x, v_m);
+        int8x16_t v_xh = (int8x16_t)vec_sr(v_x, 4);
+
+        v_xl = vec_perm(v_k, v_k, (uchar8x16_t)v_xl);
+        v_xh = vec_perm(v_k, v_k, (uchar8x16_t)v_xh);
+
+        const int8x16_t v_yl = vec_xl(0      , y0->qs);
+        const int8x16_t v_yh = vec_xl(QK8_0/2, y0->qs);
+        const int32x4_t v_xy = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh);
+
+        sumf += GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d) * (v_xy[0] + v_xy[1] + v_xy[2] + v_xy[3]);
+    }
 #endif
     for (; ib < nb; ++ib) {
         const float d = GGML_FP16_TO_FP32(y[ib].d)*GGML_FP16_TO_FP32(x[ib].d);
@@ -10721,67 +12113,31 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void *
 #elif defined(__loongarch_asx)
 
     const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0);
-    const __m128i m4b  = __lsx_vreplgr2vr_b(0x0f);
 
     __m256 accum = (__m256)__lasx_xvldi(0);
-    __m256i tmp1;
-    __m128i tmp0, tmp2, tmp3, tmp4, mask_8f, mask;
 
-    mask_8f = __lsx_vreplgr2vr_b(0x8f);
     for (int ibl = 0; ibl < nb; ++ibl) {
         const uint8_t * qs = x[ibl].qs;
         const int8_t  * q8 = y[ibl].qs;
         uint16_t sh = x[ibl].scales_h;
         __m256i sumi1 = __lasx_xvldi(0);
         __m256i sumi2 = __lasx_xvldi(0);
-        __m128i zero = __lsx_vldi(0);
         for (int ib = 0; ib < QK_K/32; ib += 2) {
-            const __m128i q4bits_1 = __lsx_vld((const __m128i*)qs, 0);  qs += 16;
-            const __m128i q4bits_2 = __lsx_vld((const __m128i*)qs, 0);  qs += 16;
+            const __m128i q4bits_1 = __lsx_vld((const __m128i*)qs, 0); qs += 16;
+            const __m128i q4bits_2 = __lsx_vld((const __m128i*)qs, 0); qs += 16;
             const __m256i q8b_1 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
             const __m256i q8b_2 = __lasx_xvld((const __m256i *)q8, 0); q8 += 32;
-            tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_1, 4), m4b), mask_8f);
-            tmp0 = __lsx_vori_b(tmp2, 0x10);
-            mask = __lsx_vsle_b(zero, tmp2);
-            tmp3 = __lsx_vand_v(tmp0, mask);
-            tmp3 = __lsx_vshuf_b(values128, zero, tmp3);
-
-            tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_1, m4b), mask_8f);
-            tmp0 = __lsx_vori_b(tmp2, 0x10);
-            mask = __lsx_vsle_b(zero, tmp2);
-            tmp4 = __lsx_vand_v(tmp0, mask);
-            tmp4 = __lsx_vshuf_b(values128, zero, tmp4);
-
-            const __m256i q4b_1 = lasx_insertf128(tmp3, tmp4);
-
-            tmp2 = __lsx_vand_v(__lsx_vand_v(__lsx_vsrli_h(q4bits_2, 4), m4b), mask_8f);
-            tmp0 = __lsx_vori_b(tmp2, 0x10);
-            mask = __lsx_vsle_b(zero, tmp2);
-            tmp3 = __lsx_vand_v(tmp0, mask);
-            tmp3 = __lsx_vshuf_b(values128, zero, tmp3);
-
-            tmp2 = __lsx_vand_v(__lsx_vand_v(q4bits_2, m4b), mask_8f);
-            tmp0 = __lsx_vori_b(tmp2, 0x10);
-            mask = __lsx_vsle_b(zero, tmp2);
-            tmp4 = __lsx_vand_v(tmp0, mask);
-            tmp4 = __lsx_vshuf_b(values128, zero, tmp4);
-
-            const __m256i q4b_2 = lasx_insertf128(tmp3, tmp4);
-
+            const __m256i q4b_1 = lasx_insertf128(__lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits_1, 4)),
+                                                  __lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits_1, 0xf)));
+            const __m256i q4b_2 = lasx_insertf128(__lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits_2, 4)),
+                                                  __lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits_2, 0xf)));
             const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
             const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
             const int16_t ls1 = ((x[ibl].scales_l[ib/2] & 0xf) | ((sh << 4) & 0x30)) - 32;
             const int16_t ls2 = ((x[ibl].scales_l[ib/2] >>  4) | ((sh << 2) & 0x30)) - 32;
             sh >>= 4;
-            __m256i tmp5, tmp6;
-            tmp1 = __lasx_xvreplgr2vr_h(ls1);
-            tmp5 = __lasx_xvmulwev_w_h(p16_1, tmp1);
-            tmp6 = __lasx_xvmulwod_w_h(p16_1, tmp1);
-            const __m256i p_1 = __lasx_xvadd_w(tmp5, tmp6);
-            tmp1 = __lasx_xvreplgr2vr_h(ls2);
-            tmp5 = __lasx_xvmulwev_w_h(p16_2, tmp1);
-            tmp6 = __lasx_xvmulwod_w_h(p16_2, tmp1);
-            const __m256i p_2 = __lasx_xvadd_w(tmp5, tmp6);
+            const __m256i p_1 = lasx_madd_h(p16_1, __lasx_xvreplgr2vr_h(ls1));
+            const __m256i p_2 = lasx_madd_h(p16_2, __lasx_xvreplgr2vr_h(ls2));
             sumi1 = __lasx_xvadd_w(p_1, sumi1);
             sumi2 = __lasx_xvadd_w(p_2, sumi2);
         }
@@ -10790,6 +12146,56 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * restrict s, size_t bs, const void *
     }
 
     *s = hsum_float_8(accum);
+#elif defined(__VXE__) || defined(__VXE2__)
+    const int8x16_t v_k = vec_xl(0, kvalues_iq4nl);
+    const uint8x16_t v_m = vec_splat_u8(0x0F);
+
+    float sumf = 0;
+
+    for (int ibl = 0; ibl < nb; ++ibl) {
+        const uint8_t * restrict q4 = x[ibl].qs;
+        const int8_t  * restrict q8 = y[ibl].qs;
+
+        uint16_t h = x[ibl].scales_h;
+
+        int sumi1 = 0, sumi2 = 0;
+        for (int ib = 0; ib < QK_K/64; ++ib) {
+            const uint8x16_t v_x0 = vec_xl(0       , q4);
+            const uint8x16_t v_x1 = vec_xl(QK4_NL/2, q4);
+            q4 += 32;
+
+            int8x16_t v_x0l = (int8x16_t)vec_and(v_x0, v_m);
+            int8x16_t v_x0h = (int8x16_t)vec_sr(v_x0, 4);
+            int8x16_t v_x1l = (int8x16_t)vec_and(v_x1, v_m);
+            int8x16_t v_x1h = (int8x16_t)vec_sr(v_x1, 4);
+
+            v_x0l = vec_perm(v_k, v_k, (uchar8x16_t)v_x0l);
+            v_x0h = vec_perm(v_k, v_k, (uchar8x16_t)v_x0h);
+            v_x1l = vec_perm(v_k, v_k, (uchar8x16_t)v_x1l);
+            v_x1h = vec_perm(v_k, v_k, (uchar8x16_t)v_x1h);
+
+            const int8x16_t v_y0 = vec_xl( 0, q8);
+            const int8x16_t v_y1 = vec_xl(16, q8);
+            const int8x16_t v_y2 = vec_xl(32, q8);
+            const int8x16_t v_y3 = vec_xl(48, q8);
+            q8 += 64;
+
+            int32x4_t vsumi0 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x0l, v_y0), v_x0h, v_y1);
+            int32x4_t vsumi1 = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_x1l, v_y2), v_x1h, v_y3);
+
+            int ls1 = ((x[ibl].scales_l[ib] & 0xF) | ((h << 4) & 0x30)) - 32;
+            int ls2 = ((x[ibl].scales_l[ib] >>  4) | ((h << 2) & 0x30)) - 32;
+
+            h >>= 4;
+
+            sumi1 += (vsumi0[0] + vsumi0[1] + vsumi0[2] + vsumi0[3]) * ls1;
+            sumi2 += (vsumi1[0] + vsumi1[1] + vsumi1[2] + vsumi1[3]) * ls2;
+        }
+
+        sumf += GGML_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2);
+    }
+
+    *s = sumf;
 
 #else
     float sumf = 0;
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c
index b307d554..2f606d82 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c
@@ -7,10 +7,8 @@
 #include "ggml-cpu-impl.h"
 #include "ggml-cpu.h"
 #include "ggml-impl.h"
-#include "ggml-quants.h"
 #include "ggml-cpu-quants.h"
 #include "ggml-threading.h"
-#include "amx/amx.h"
 #include "ggml.h"
 
 #if defined(_MSC_VER) || defined(__MINGW32__)
@@ -114,7 +112,8 @@ struct ggml_arm_arch_features_type {
     int has_i8mm;
     int has_sve;
     int sve_cnt;
-} ggml_arm_arch_features = {-1, -1, -1, -1, 0};
+    int has_sme;
+} ggml_arm_arch_features = {-1, -1, -1, -1, 0, -1};
 #endif
 
 
@@ -238,6 +237,8 @@ typedef pthread_t ggml_thread_t;
 #else
 #if defined(__POWER9_VECTOR__)
 #define CACHE_LINE_SIZE 128
+#elif defined(__VXE__) || defined(__VXE2__)
+#define CACHE_LINE_SIZE 256
 #else
 #define CACHE_LINE_SIZE 64
 #endif
@@ -1078,29 +1079,23 @@ do {                                                              \
 #define GGML_F16_STEP 32
 #define GGML_F16_EPR  8
 
-// F16 arithmetic is not supported by AVX, so we use F32 instead
+// F16 arithmetic is not supported by LASX, so we use F32 instead
 
 #define GGML_F32Cx8          __m256
 #define GGML_F32Cx8_ZERO    (__m256)__lasx_xvldi(0)
 #define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
 
 static inline __m256 __lasx_f32cx8_load(const ggml_fp16_t * x) {
-    float tmp[8];
-
-    for (int i = 0; i < 8; i++) {
-        tmp[i] = GGML_FP16_TO_FP32(x[i]);
-    }
-
-    return (__m256)__lasx_xvld(tmp, 0);
+    __m256i a;
+    memcpy(&a, x, sizeof(ggml_fp16_t) * 8);
+    a = __lasx_xvpermi_d(a, 0 | (1 << 4));
+    return __lasx_xvfcvtl_s_h(a);
 }
+
 static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) {
-    float arr[8];
-
-    __lasx_xvst(y, arr, 0);
-
-    for (int i = 0; i < 8; i++) {
-        x[i] = GGML_FP32_TO_FP16(arr[i]);
-    }
+    __m256i a = __lasx_xvfcvt_h_s(y, y);
+    a = __lasx_xvpermi_d(a, 0 | (2 << 2));
+    memcpy(x, &a, sizeof(ggml_fp16_t) * 8);
 }
 #define GGML_F32Cx8_LOAD(x)     __lasx_f32cx8_load(x)
 #define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y)
@@ -1218,6 +1213,87 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
 #define GGML_F16_VEC_MUL             GGML_F32Cx4_MUL
 #define GGML_F16_VEC_REDUCE          GGML_F32Cx4_REDUCE
 
+#elif defined(__VXE__) || defined(__VXE2__)
+
+#define GGML_SIMD
+
+// F32 s390x
+
+#define GGML_F32_STEP 32
+#define GGML_F32_EPR  4
+
+#define GGML_F32x4              __vector float
+#define GGML_F32x4_ZERO         vec_splats(0.0f)
+#define GGML_F32x4_SET1         vec_splats
+#define GGML_F32x4_LOAD(p)      vec_xl(0, p)
+#define GGML_F32x4_STORE(p, r)  vec_xst(r, 0, p)
+#define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a)
+#define GGML_F32x4_ADD          vec_add
+#define GGML_F32x4_MUL          vec_mul
+#define GGML_F32x4_REDUCE(res, x)                   \
+{                                                   \
+    int offset = GGML_F32_ARR >> 1;                 \
+    for (int i = 0; i < offset; ++i) {              \
+        x[i] = vec_add(x[i], x[offset + i]);        \
+    }                                               \
+    offset >>= 1;                                   \
+    for (int i = 0; i < offset; ++i) {              \
+        x[i] = vec_add(x[i], x[offset + i]);        \
+    }                                               \
+    offset >>= 1;                                   \
+    for (int i = 0; i < offset; ++i) {              \
+        x[i] = vec_add(x[i], x[offset + i]);        \
+    }                                               \
+    res = vec_extract(x[0], 0) +                    \
+          vec_extract(x[0], 1) +                    \
+          vec_extract(x[0], 2) +                    \
+          vec_extract(x[0], 3);                     \
+}
+
+#define GGML_F32_VEC        GGML_F32x4
+#define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO
+#define GGML_F32_VEC_SET1   GGML_F32x4_SET1
+#define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD
+#define GGML_F32_VEC_STORE  GGML_F32x4_STORE
+#define GGML_F32_VEC_FMA    GGML_F32x4_FMA
+#define GGML_F32_VEC_ADD    GGML_F32x4_ADD
+#define GGML_F32_VEC_MUL    GGML_F32x4_MUL
+#define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE
+
+// F16 s390x
+#define GGML_F16_STEP GGML_F32_STEP
+#define GGML_F16_EPR  GGML_F32_EPR
+
+static inline __vector float __lzs_f16cx4_load(const ggml_fp16_t * x) {
+    float tmp[4];
+
+    for (int i = 0; i < 4; i++) {
+        tmp[i] = GGML_FP16_TO_FP32(x[i]);
+    }
+
+    return vec_xl(0, tmp);
+}
+
+static inline void __lzs_f16cx4_store(ggml_fp16_t * x, __vector float y) {
+    float arr[4];
+
+    vec_xst(y, 0, arr);
+
+    for (int i = 0; i < 4; i++) {
+        x[i] = GGML_FP32_TO_FP16(arr[i]);
+    }
+}
+
+#define GGML_F16_VEC                GGML_F32x4
+#define GGML_F16_VEC_ZERO           GGML_F32x4_ZERO
+#define GGML_F16_VEC_SET1           GGML_F32x4_SET1
+#define GGML_F16_VEC_LOAD(p, i)     __lzs_f16cx4_load(p)
+#define GGML_F16_VEC_STORE(p, r, i) __lzs_f16cx4_store(p, r[i])
+#define GGML_F16_VEC_FMA            GGML_F32x4_FMA
+#define GGML_F16_VEC_ADD            GGML_F32x4_ADD
+#define GGML_F16_VEC_MUL            GGML_F32x4_MUL
+#define GGML_F16_VEC_REDUCE         GGML_F32x4_REDUCE
+
 #endif
 
 // GGML_F32_ARR / GGML_F16_ARR
@@ -1297,12 +1373,12 @@ struct ggml_threadpool {
     atomic_int n_graph;       // incremented when there is work to be done (i.e each graph)
     atomic_int GGML_CACHE_ALIGN n_barrier;
     atomic_int GGML_CACHE_ALIGN n_barrier_passed;
-    atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
+    atomic_int GGML_CACHE_ALIGN current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
 
     // these are atomic as an annotation for thread-sanitizer
     atomic_bool stop;         // Used for stopping the threadpool altogether
     atomic_bool pause;        // Used for pausing the threadpool or individual threads
-    atomic_bool abort;        // Used for aborting processing of a graph
+    atomic_int abort;         // Used for aborting processing of a graph
 
     struct ggml_compute_state * workers;   // per thread state
     int          n_threads_max; // number of threads in the pool
@@ -1824,7 +1900,7 @@ inline static float ggml_silu_f32(float x) {
 
 #if __FINITE_MATH_ONLY__
 #error "some routines in ggml.c require non-finite math arithmetics -- pass -fno-finite-math-only to the compiler to fix"
-#error "ref: https://github.com/ggerganov/llama.cpp/pull/7154#issuecomment-2143844461"
+#error "ref: https://github.com/ggml-org/llama.cpp/pull/7154#issuecomment-2143844461"
 #endif
 
 #if defined(__ARM_NEON) && defined(__aarch64__)
@@ -2389,15 +2465,20 @@ bool ggml_is_numa(void) {
 #define HWCAP2_I8MM (1 << 13)
 #endif
 
+#if !defined(HWCAP2_SME)
+#define HWCAP2_SME (1 << 23)
+#endif
+
 static void ggml_init_arm_arch_features(void) {
 #if defined(__linux__) && defined(__aarch64__)
     uint32_t hwcap = getauxval(AT_HWCAP);
     uint32_t hwcap2 = getauxval(AT_HWCAP2);
 
-    ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD);
+    ggml_arm_arch_features.has_neon    = !!(hwcap & HWCAP_ASIMD);
     ggml_arm_arch_features.has_dotprod = !!(hwcap & HWCAP_ASIMDDP);
-    ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM);
-    ggml_arm_arch_features.has_sve  = !!(hwcap & HWCAP_SVE);
+    ggml_arm_arch_features.has_i8mm    = !!(hwcap2 & HWCAP2_I8MM);
+    ggml_arm_arch_features.has_sve     = !!(hwcap & HWCAP_SVE);
+    ggml_arm_arch_features.has_sme     = !!(hwcap2 & HWCAP2_SME);
 
 #if defined(__ARM_FEATURE_SVE)
     ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
@@ -2420,6 +2501,11 @@ static void ggml_init_arm_arch_features(void) {
     }
     ggml_arm_arch_features.has_i8mm = oldp;
 
+    if (sysctlbyname("hw.optional.arm.FEAT_SME", &oldp, &size, NULL, 0) != 0) {
+        oldp = 0;
+    }
+    ggml_arm_arch_features.has_sme = oldp;
+
     ggml_arm_arch_features.has_sve = 0;
     ggml_arm_arch_features.sve_cnt = 0;
 #else
@@ -2443,6 +2529,12 @@ static void ggml_init_arm_arch_features(void) {
     ggml_arm_arch_features.has_sve = 0;
     ggml_arm_arch_features.sve_cnt = 0;
 #endif
+
+#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_SME2)
+    ggml_arm_arch_features.has_sme = 1;
+#else
+    ggml_arm_arch_features.has_sme = 0;
+#endif
 #endif
 }
 #endif
@@ -3967,6 +4059,57 @@ static void ggml_compute_forward_dup_bytes(
     }
 }
 
+static void ggml_compute_forward_dup_q(
+        const struct ggml_compute_params * params,
+              struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const enum ggml_type type = src0->type;
+    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
+
+    size_t qk = ggml_blck_size(type);
+    const int64_t nr = ggml_nelements(src1) / qk;
+
+    // destination must be contiguous in the first dimension
+    GGML_ASSERT(nb10 == ggml_type_size(dst->type));
+    // must either have first dimension large enough to hold a row, or fully contiguous
+    GGML_ASSERT((ne10 % qk) == 0 || ggml_is_contiguous(dst));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+
+        uint32_t i = ir * qk;
+
+        const int64_t i03 = i/(ne00 * ne01 * ne02);
+        const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+        const int64_t i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
+        const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+        const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+
+        const int64_t i13 = i/(ne10 * ne11 * ne12);
+        const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+        const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+        const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+        const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
+
+        dequantize_row_q(
+                (const void *) ((char *) src0->data + x_offset),
+                     (float *) ((char *)  dst->data + dst_offset), qk);
+    }
+}
+
 static void ggml_compute_forward_dup(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -3993,6 +4136,10 @@ static void ggml_compute_forward_dup(
             } break;
         default:
             {
+                if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) {
+                    ggml_compute_forward_dup_q(params, dst);
+                    break;
+                }
                 GGML_ABORT("fatal error");
             }
     }
@@ -6691,20 +6838,20 @@ static void ggml_compute_forward_silu_back_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * grad = dst->src[1];
+    const struct ggml_tensor * grad = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
 
     assert(ggml_is_contiguous_1(grad));
-    assert(ggml_is_contiguous_1(src0));
+    assert(ggml_is_contiguous_1(src1));
     assert(ggml_is_contiguous_1(dst));
-    assert(ggml_are_same_shape(src0, dst));
-    assert(ggml_are_same_shape(src0, grad));
+    assert(ggml_are_same_shape(src1, dst));
+    assert(ggml_are_same_shape(src1, grad));
 
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const int nc = src0->ne[0];
-    const int nr = ggml_nrows(src0);
+    const int nc = src1->ne[0];
+    const int nr = ggml_nrows(src1);
 
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
@@ -6716,7 +6863,7 @@ static void ggml_compute_forward_silu_back_f32(
     for (int i1 = ir0; i1 < ir1; i1++) {
         ggml_vec_silu_backward_f32(nc,
                 (float *) ((char *) dst->data  + i1*( dst->nb[1])),
-                (float *) ((char *) src0->data + i1*(src0->nb[1])),
+                (float *) ((char *) src1->data + i1*(src1->nb[1])),
                 (float *) ((char *) grad->data + i1*(grad->nb[1])));
 
 #ifndef NDEBUG
@@ -6895,7 +7042,7 @@ static void ggml_compute_forward_norm_f32(
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
 
-    GGML_ASSERT(eps > 0.0f);
+    GGML_ASSERT(eps >= 0.0f);
 
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -6966,7 +7113,7 @@ static void ggml_compute_forward_rms_norm_f32(
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
 
-    GGML_ASSERT(eps > 0.0f);
+    GGML_ASSERT(eps >= 0.0f);
 
     // TODO: optimize
     for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -7018,12 +7165,13 @@ static void ggml_compute_forward_rms_norm_back_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * src0 = dst->src[0]; // gradients from forward pass output
+    const struct ggml_tensor * src1 = dst->src[1]; // src1 from forward pass
 
     GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));
 
     GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -7042,8 +7190,8 @@ static void ggml_compute_forward_rms_norm_back_f32(
                 const int64_t i12 = i02;
                 const int64_t i13 = i03;
 
-                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-                const float * dz = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
+                const float * dz = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+                const float * x  = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
 
                 ggml_float sum_xx  = 0.0;
                 ggml_float sum_xdz = 0.0;
@@ -7066,9 +7214,9 @@ static void ggml_compute_forward_rms_norm_back_f32(
                 {
                     // z = rms_norm(x)
                     //
-                    // rms_norm(src0) =
+                    // rms_norm(src1) =
                     //     scale(
-                    //         src0,
+                    //         src1,
                     //         div(
                     //             1,
                     //             sqrt(
@@ -7076,13 +7224,13 @@ static void ggml_compute_forward_rms_norm_back_f32(
                     //                     scale(
                     //                         sum(
                     //                             sqr(
-                    //                                 src0)),
+                    //                                 src1)),
                     //                         (1.0/N)),
                     //                     eps))));
 
                     // postorder:
                     // ## op    args         grad
-                    // 00 param src0         grad[#00]
+                    // 00 param src1         grad[#00]
                     // 01 const 1
                     // 02 sqr   (#00)        grad[#02]
                     // 03 sum   (#02)        grad[#03]
@@ -7159,6 +7307,7 @@ static void ggml_compute_forward_rms_norm_back_f32(
                 // dx := scale(dx, rrms)
                 float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
 
+                // dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps)
                 ggml_vec_cpy_f32  (ne00, dx, x);
                 // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
                 ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
@@ -7439,6 +7588,7 @@ UseGgmlGemm1:;
     if (src1->type != vec_dot_type) {
         char * wdata = params->wdata;
 
+        const size_t nbw0 = ggml_type_size(vec_dot_type);
         const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
         const size_t nbw2 = nbw1*ne11;
         const size_t nbw3 = nbw2*ne12;
@@ -7446,6 +7596,7 @@ UseGgmlGemm1:;
         assert(params->wsize >= ne13*nbw3);
         GGML_ASSERT(src1->type == GGML_TYPE_F32);
 
+    #if 0
         for (int64_t i13 = 0; i13 < ne13; ++i13) {
             for (int64_t i12 = 0; i12 < ne12; ++i12) {
                 for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
@@ -7455,6 +7606,20 @@ UseGgmlGemm1:;
                 }
             }
         }
+    #else
+        for (int64_t i13 = 0; i13 < ne13; ++i13) {
+            for (int64_t i12 = 0; i12 < ne12; ++i12) {
+                for (int64_t i11 = 0; i11 < ne11; ++i11) {
+                    size_t bs = ggml_blck_size(vec_dot_type);
+                    int64_t ne10_block_start = (ith * ne10/bs) / nth;
+                    int64_t ne10_block_end   = ((ith + 1) * ne10/bs) / nth;
+                    from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10),
+                               (void *)               (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0),
+                               (ne10_block_end - ne10_block_start) * bs);
+                }
+            }
+        }
+    #endif
     }
 
     if (ith == 0) {
@@ -7509,7 +7674,7 @@ UseGgmlGemm2:;
     int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
 
     // If the chunking is poor for the number of threads on this setup, scrap the whole plan.  Re-chunk it by thread.
-    //   Also, chunking by thread was measured to have perform better on NUMA systems.  See https://github.com/ggerganov/llama.cpp/pull/6915
+    //   Also, chunking by thread was measured to have perform better on NUMA systems.  See https://github.com/ggml-org/llama.cpp/pull/6915
     //   In theory, chunking should be just as useful on NUMA and non NUMA systems, but testing disagreed with that.
     if (nchunk0 * nchunk1 < nth * 4 || ggml_is_numa()) {
         // distribute the thread work across the inner or outer loop based on which one is larger
@@ -7542,7 +7707,6 @@ UseGgmlGemm2:;
         if ((nr0 % 2 != 0) || (ne11 % 2 != 0) || ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) {
             num_rows_per_vec_dot = 1;
         }
-
         ggml_compute_forward_mul_mat_one_chunk(params, dst, src0->type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
 
         if (nth >= nchunk0 * nchunk1) {
@@ -7555,6 +7719,84 @@ UseGgmlGemm2:;
 
 // ggml_compute_forward_mul_mat_id
 
+#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ids->ne[0]*ids->ne[1] + (i1)]
+
+struct mmid_row_mapping {
+    int32_t i1;
+    int32_t i2;
+};
+
+static void ggml_compute_forward_mul_mat_id_one_chunk(
+    struct ggml_tensor * dst,
+    const struct ggml_tensor * src0,
+    const struct ggml_tensor * src1,
+    const struct ggml_tensor * ids,
+    const int64_t cur_a,
+    const int64_t ir0_start,
+    const int64_t ir0_end,
+    const int64_t ir1_start,
+    const int64_t ir1_end,
+    const char * src0_cur,
+    const struct mmid_row_mapping * matrix_rows,
+    const size_t row_size,
+    const bool src1_cont,
+    const void * wdata) {
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const enum ggml_type type = src0->type;
+
+    ggml_vec_dot_t    const vec_dot      = type_traits_cpu[type].vec_dot;
+    enum ggml_type    const vec_dot_type = type_traits_cpu[type].vec_dot_type;
+
+    const int64_t blck_0 = 16;
+    const int64_t blck_1 = 16;
+
+    float tmp[16];
+
+    for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
+        for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
+            for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ++ir1) {
+                const int64_t _i12 = ir1; // logical row index for this expert
+
+                struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
+                const int id       = row_mapping.i1; // selected expert index
+
+                const int64_t  i11 = id % ne11;
+                const int64_t  i12 = row_mapping.i2; // row index in src1
+
+                const int64_t  i1 = id;  // selected expert index
+                const int64_t  i2 = i12; // row
+
+                // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
+                //       if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
+                //       the original src1 data pointer, so we should index using the indices directly
+                // TODO: this is a bit of a hack, we should probably have a better way to handle this
+                const char * src1_col = (const char *) wdata +
+                    (src1_cont || src1->type != vec_dot_type
+                    ? (i11      + i12*ne11)*row_size
+                    : (i11*nb11 + i12*nb12));
+
+                float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));
+
+                for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
+                    vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
+                }
+
+                memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir0_end) - iir0)*sizeof(float));
+            }
+        }
+    }
+}
+
+static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
+
+    void * ptr = *p;
+    ptr = (void *) GGML_PAD((uintptr_t) ptr, align);
+    *p = (void *) ((char *) ptr + size);
+    return ptr;
+}
+
 static void ggml_compute_forward_mul_mat_id(
         const struct ggml_compute_params * params,
               struct ggml_tensor * dst) {
@@ -7572,7 +7814,6 @@ static void ggml_compute_forward_mul_mat_id(
 
     const bool src1_cont = ggml_is_contiguous(src1);
 
-    ggml_vec_dot_t    const vec_dot         = type_traits_cpu[type].vec_dot;
     enum ggml_type    const vec_dot_type    = type_traits_cpu[type].vec_dot_type;
     ggml_from_float_t const from_float      = type_traits_cpu[vec_dot_type].from_float;
 
@@ -7590,21 +7831,27 @@ static void ggml_compute_forward_mul_mat_id(
     const int n_ids = ids->ne[0]; // n_expert_used
     const int n_as  = ne02;       // n_expert
 
-    char * wdata_src1_end = (src1->type == vec_dot_type) ?
-            (char *) params->wdata :
-            (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
+    void * wdata_cur = params->wdata;
 
-    struct mmid_row_mapping {
-        int32_t i1;
-        int32_t i2;
-    };
+    if (src1->type != vec_dot_type) {
+        incr_ptr_aligned(&wdata_cur, ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
+    }
 
-    int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
-    struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11]
+    int64_t * matrix_row_counts = // [n_as]
+        incr_ptr_aligned(&wdata_cur, n_as*sizeof(int64_t), sizeof(int64_t));
+
+    struct mmid_row_mapping * matrix_rows = // [n_as][ids->ne[0]*ids->ne[1]]
+        incr_ptr_aligned(&wdata_cur, n_as*ids->ne[0]*ids->ne[1]*sizeof(struct mmid_row_mapping), sizeof(int64_t));
+
+    char (*atomic_current_chunk)[CACHE_LINE_SIZE] = // [n_as]
+        incr_ptr_aligned(&wdata_cur, CACHE_LINE_SIZE * n_as, CACHE_LINE_SIZE);
+
+    GGML_ASSERT(params->wsize >= (size_t)((char *) wdata_cur - (char *) params->wdata));
 
     if (src1->type != vec_dot_type) {
         char * wdata = params->wdata;
 
+        const size_t nbw0 = ggml_type_size(vec_dot_type);
         const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
         const size_t nbw2 = nbw1*ne11;
         const size_t nbw3 = nbw2*ne12;
@@ -7612,19 +7859,32 @@ static void ggml_compute_forward_mul_mat_id(
         assert(params->wsize >= ne13*nbw3);
         GGML_ASSERT(src1->type == GGML_TYPE_F32);
 
+#if 0
         for (int64_t i13 = 0; i13 < ne13; ++i13) {
-            for (int64_t i12 = 0; i12 < ne12; ++i12) {
-                for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
+            for (int64_t i12 = ith; i12 < ne12; i12 += nth) {
+                for (int64_t i11 = 0; i11 < ne11; ++i11) {
                     from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
                                (void *)               (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
                                ne10);
                 }
             }
         }
+#else
+        for (int64_t i13 = 0; i13 < ne13; ++i13) {
+            for (int64_t i12 = 0; i12 < ne12; ++i12) {
+                for (int64_t i11 = 0; i11 < ne11; ++i11) {
+                    size_t bs = ggml_blck_size(vec_dot_type);
+                    int64_t ne10_block_start = (ith * ne10/bs) / nth;
+                    int64_t ne10_block_end   = ((ith + 1) * ne10/bs) / nth;
+                    from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + ne10_block_start*bs*nb10),
+                               (void *)               (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1 + ne10_block_start*nbw0),
+                               (ne10_block_end - ne10_block_start) * bs);
+                }
+            }
+        }
+#endif
     }
 
-#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
-
     if (ith == 0) {
         // initialize matrix_row_counts
         memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
@@ -7642,9 +7902,14 @@ static void ggml_compute_forward_mul_mat_id(
         }
     }
 
+    // reset current_chunk
+    for (int cur_a = ith; cur_a < n_as; cur_a += nth) {
+        atomic_int * current_chunk_ctr = (atomic_int *)(atomic_current_chunk + cur_a);
+        *current_chunk_ctr = nth;
+    }
+
     ggml_barrier(params->threadpool);
 
-    // compute each matrix multiplication in sequence
     for (int cur_a = 0; cur_a < n_as; ++cur_a) {
         const int64_t cne1 = matrix_row_counts[cur_a];
 
@@ -7652,84 +7917,64 @@ static void ggml_compute_forward_mul_mat_id(
             continue;
         }
 
-        const char * src0_cur = (const char *) src0->data + cur_a*nb02;
-
-        const void * wdata    = (src1->type == vec_dot_type) ? src1->data : params->wdata;
+        const char * src0_cur = (const char *) src0->data + cur_a * nb02;
+        const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
         const size_t row_size = ggml_row_size(vec_dot_type, ne10);
 
-        const int64_t nr0 = ne01; // src0 rows
-        const int64_t nr1 = cne1; // src1 rows
+        const int64_t nr0 = ne01;
+        const int64_t nr1 = cne1;
 
-        // distribute the thread work across the inner or outer loop based on which one is larger
+        int chunk_size = 16;
+        if (nr0 == 1 || nr1 == 1) {
+            chunk_size = 64;
+        }
 
-        const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
-        const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
+#if defined(__aarch64__)
+        // disable for ARM
+        const bool disable_chunking = true;
+#else
+        // disable for NUMA
+        const bool disable_chunking = ggml_is_numa();
+#endif // defined(__aarch64__)
 
-        const int64_t ith0 = ith % nth0;
-        const int64_t ith1 = ith / nth0;
+        int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
+        int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
 
-        const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
-        const int64_t dr1 = (nr1 + nth1 - 1)/nth1;
+        if (nchunk0 * nchunk1 < nth * 4 || disable_chunking) {
+            nchunk0 = nr0 > nr1 ? nth : 1;
+            nchunk1 = nr0 > nr1 ? 1 : nth;
+        }
 
-        const int64_t ir010 = dr0*ith0;
-        const int64_t ir011 = MIN(ir010 + dr0, nr0);
+        const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
+        const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
 
-        const int64_t ir110 = dr1*ith1;
-        const int64_t ir111 = MIN(ir110 + dr1, nr1);
+        int current_chunk = ith;
 
-        // threads with no work simply yield (not sure if it helps)
-        //if (ir010 >= ir011 || ir110 >= ir111) {
-        //    sched_yield();
-        //    continue;
-        //}
+        atomic_int * current_chunk_ctr = (atomic_int *)(atomic_current_chunk + cur_a);
 
-        // block-tiling attempt
-        const int64_t blck_0 = 16;
-        const int64_t blck_1 = 16;
+        while (current_chunk < nchunk0 * nchunk1) {
+            const int64_t ith0 = current_chunk % nchunk0;
+            const int64_t ith1 = current_chunk / nchunk0;
 
-        // attempt to reduce false-sharing (does not seem to make a difference)
-        float tmp[16];
+            const int64_t ir0_start = dr0 * ith0;
+            const int64_t ir0_end = MIN(ir0_start + dr0, nr0);
 
-        for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
-            for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
-                for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
-                    const int64_t _i12 = ir1; // logical row index for this expert
+            const int64_t ir1_start = dr1 * ith1;
+            const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
 
-                    struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
-                    const int id       = row_mapping.i1; // selected expert index
+            ggml_compute_forward_mul_mat_id_one_chunk(
+                dst, src0, src1, ids, cur_a,
+                ir0_start, ir0_end, ir1_start, ir1_end,
+                src0_cur, matrix_rows, row_size, src1_cont, wdata
+            );
 
-                    const int64_t  i11 = id % ne11;
-                    const int64_t  i12 = row_mapping.i2; // row index in src1
-
-                    const int64_t  i1 = id;  // selected expert index
-                    const int64_t  i2 = i12; // row
-
-                    // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
-                    //       if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
-                    //       the original src1 data pointer, so we should index using the indices directly
-                    // TODO: this is a bit of a hack, we should probably have a better way to handle this
-                    const char * src1_col = (const char *) wdata +
-                        (src1_cont || src1->type != vec_dot_type
-                        ? (i11      + i12*ne11)*row_size
-                        : (i11*nb11 + i12*nb12));
-
-                    float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));
-
-                    //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
-                    //    vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
-                    //}
-
-                    for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
-                        vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
-                    }
-
-                    memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
-                }
+            if (nth >= nchunk0 * nchunk1) {
+                break;
             }
+
+            current_chunk = atomic_fetch_add_explicit(current_chunk_ctr, 1, memory_order_relaxed);
         }
     }
-
-#undef MMID_MATRIX_ROW
 }
 
 // ggml_compute_forward_out_prod
@@ -7750,12 +7995,13 @@ static void ggml_compute_forward_out_prod_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    GGML_ASSERT(ne0  == ne00);
-    GGML_ASSERT(ne1  == ne10);
-    GGML_ASSERT(ne2  == ne02);
-    GGML_ASSERT(ne02 == ne12);
-    GGML_ASSERT(ne3  == ne13);
-    GGML_ASSERT(ne03 == ne13);
+    GGML_ASSERT(ne0 == ne00);
+    GGML_ASSERT(ne1 == ne10);
+    GGML_ASSERT(ne2 == ne12);
+    GGML_ASSERT(ne3 == ne13);
+
+    GGML_ASSERT(ne2 % ne02 == 0);
+    GGML_ASSERT(ne3 % ne03 == 0);
 
     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == sizeof(float));
@@ -7797,6 +8043,10 @@ static void ggml_compute_forward_out_prod_f32(
     const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
     const int64_t blck_1 = 16;
 
+    // dps == dst per src0, used for group query attention
+    const int64_t dps2 = ne2 / ne02;
+    const int64_t dps3 = ne3 / ne03;
+
     for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
         const int64_t bir1 = MIN(bir + blck_1, ir1);
         for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
@@ -7807,8 +8057,8 @@ static void ggml_compute_forward_out_prod_f32(
                 const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
                 const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
 
-                const int64_t i02 = i2;
-                const int64_t i03 = i3;
+                const int64_t i02 = i2 / dps2;
+                const int64_t i03 = i3 / dps3;
 
                 //const int64_t i10 = i1;
                 const int64_t i12 = i2;
@@ -7821,7 +8071,7 @@ static void ggml_compute_forward_out_prod_f32(
 
                     float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
                     float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
-                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
+                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1   + i2*nb2   + i3*nb3));
 
                     ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
                 }
@@ -7830,7 +8080,7 @@ static void ggml_compute_forward_out_prod_f32(
 
                     float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
                     float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
-                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
+                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1   + i2*nb2   + i3*nb3));
 
                     ggml_vec_mad_f32(ne0, d, s0, *s1);
                 }
@@ -8906,9 +9156,9 @@ static void ggml_compute_forward_soft_max(
 }
 
 
-// ggml_compute_forward_soft_max_back
+// ggml_compute_forward_soft_max_ext_back
 
-static void ggml_compute_forward_soft_max_back_f32(
+static void ggml_compute_forward_soft_max_ext_back_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
@@ -8921,6 +9171,14 @@ static void ggml_compute_forward_soft_max_back_f32(
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
     GGML_ASSERT(ggml_are_same_shape(src1, dst));
 
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (const float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
+
+    GGML_ASSERT(max_bias == 0.0f);
+
     // TODO: handle transposed/permuted matrices
 
     const int ith = params->ith;
@@ -8969,10 +9227,11 @@ static void ggml_compute_forward_soft_max_back_f32(
 
         // linear runtime, no additional memory
         float dot_y_dy = 0;
-        ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
-        ggml_vec_cpy_f32 (nc, dx, dy);
-        ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
-        ggml_vec_mul_f32 (nc, dx, dx, y);
+        ggml_vec_dot_f32  (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
+        ggml_vec_cpy_f32  (nc, dx, dy);
+        ggml_vec_acc1_f32 (nc, dx, -dot_y_dy);
+        ggml_vec_mul_f32  (nc, dx, dx, y);
+        ggml_vec_scale_f32(nc, dx, scale);
 
 #ifndef NDEBUG
         for (int i = 0; i < nc; ++i) {
@@ -8983,7 +9242,7 @@ static void ggml_compute_forward_soft_max_back_f32(
     }
 }
 
-static void ggml_compute_forward_soft_max_back(
+static void ggml_compute_forward_soft_max_ext_back(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
@@ -8992,7 +9251,7 @@ static void ggml_compute_forward_soft_max_back(
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_soft_max_back_f32(params, dst);
+                ggml_compute_forward_soft_max_ext_back_f32(params, dst);
             } break;
         default:
             {
@@ -9009,10 +9268,6 @@ static void ggml_compute_forward_clamp_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    if (params->ith != 0) {
-        return;
-    }
-
     float min;
     float max;
     memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
@@ -9985,9 +10240,10 @@ static void ggml_compute_forward_im2col_back_f32(
         const struct ggml_compute_params * params,
               struct ggml_tensor * dst) {
 
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
+    const struct ggml_tensor * src1 = dst->src[1]; // convolution kernel
 
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
@@ -10009,11 +10265,11 @@ static void ggml_compute_forward_im2col_back_f32(
     const int64_t IH = is_2D ? ne1 : 1;
     const int64_t IW = ne0;
 
-    const int64_t KH = is_2D ? ne01 : 1;
-    const int64_t KW = ne00;
+    const int64_t KH = is_2D ? ne11 : 1;
+    const int64_t KW = ne10;
 
-    const int64_t OH = is_2D ? ne12 : 1;
-    const int64_t OW = ne11;
+    const int64_t OH = is_2D ? ne02 : 1;
+    const int64_t OW = ne01;
 
     int ofs0 = is_2D ? nb3 : nb2;
     int ofs1 = is_2D ? nb2 : nb1;
@@ -10059,9 +10315,9 @@ static void ggml_compute_forward_im2col_back_f32(
                                     continue;
                                 }
 
-                                const float * const src_data = (const float *) src1->data
+                                const float * const grad_in = (const float *) src0->data
                                     + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
-                                grad += src_data[iic*(KH*KW) + ikh*KW + ikw];
+                                grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
                             }
                         }
                         float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
@@ -11856,9 +12112,9 @@ static void ggml_compute_forward_add_rel_pos(
 static void ggml_compute_forward_rwkv_wkv6_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
-    const int64_t T = dst->src[1]->ne[3];
+    const int64_t T = dst->src[1]->ne[2];
     const int64_t C = dst->ne[0];
-    const int64_t HEADS = dst->src[1]->ne[2];
+    const int64_t HEADS = dst->src[1]->ne[1];
     const int64_t n_seqs = dst->src[5]->ne[1];
     const int64_t head_size = C / HEADS;
 
@@ -12053,6 +12309,197 @@ static void ggml_compute_forward_rwkv_wkv6(
     }
 }
 
+// ggml_compute_forward_gla
+
+static void ggml_compute_forward_gla_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const int64_t T = dst->src[1]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t HEADS = dst->src[1]->ne[1];
+    const int64_t n_seqs = dst->src[4]->ne[1];
+    const int64_t head_size = C / HEADS;
+    const float scale = ggml_get_op_params_f32(dst, 0);
+
+    float * dst_data = (float *) dst->data;
+    float * state = ((float *) dst->data) + C * T;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    if (ith >= HEADS) {
+        return;
+    }
+
+    const int h_start = (HEADS * ith) / nth;
+    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                (HEADS * (ith + 1)) / nth : HEADS;
+
+    float * k = (float *) dst->src[0]->data;
+    float * v = (float *) dst->src[1]->data;
+    float * q = (float *) dst->src[2]->data;
+    float * g = (float *) dst->src[3]->data;
+
+    size_t t_stride = HEADS * head_size; // Same to C
+
+    size_t h_stride = C / HEADS;
+    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
+    size_t h_stride_2d = head_size * head_size;
+
+    if (ith == 0) {
+        memset(dst_data, 0, T * C * sizeof(float));
+    }
+    ggml_barrier(params->threadpool);
+
+
+    #if defined(__AVX__) && !defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x8
+        #define GGML_F32X_SET1 GGML_F32x8_SET1
+        #define GGML_F32X_LOAD GGML_F32x8_LOAD
+        #define GGML_F32X_STORE GGML_F32x8_STORE
+        #define GGML_F32X_MUL GGML_F32x8_MUL
+        #define GGML_F32X_FMA GGML_F32x8_FMA
+        #define GLA_VECTOR_SIZE 8
+    #elif defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x16
+        #define GGML_F32X_SET1 GGML_F32x16_SET1
+        #define GGML_F32X_LOAD GGML_F32x16_LOAD
+        #define GGML_F32X_STORE GGML_F32x16_STORE
+        #define GGML_F32X_MUL GGML_F32x16_MUL
+        #define GGML_F32X_FMA GGML_F32x16_FMA
+        #define GLA_VECTOR_SIZE 16
+    #elif defined(__ARM_NEON) && defined(__aarch64__)
+        #define GGML_F32X GGML_F32x4
+        #define GGML_F32X_SET1 GGML_F32x4_SET1
+        #define GGML_F32X_LOAD GGML_F32x4_LOAD
+        #define GGML_F32X_STORE GGML_F32x4_STORE
+        #define GGML_F32X_MUL GGML_F32x4_MUL
+        #define GGML_F32X_FMA GGML_F32x4_FMA
+        #define GLA_VECTOR_SIZE 4
+    #endif
+
+    #ifdef GLA_VECTOR_SIZE
+        const int64_t vec_count = head_size / GLA_VECTOR_SIZE;
+
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float q_val = q[t_h_i_offset] * scale;
+                    float g_val = g[t_h_i_offset];
+
+                    // Broadcast scalar values to vectors
+                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);
+                    GGML_F32X q_vec = GGML_F32X_SET1(q_val);
+                    GGML_F32X g_vec = GGML_F32X_SET1(g_val);
+
+                    for (int64_t j = 0; j < vec_count; j++) {
+                        size_t base_j = j * GLA_VECTOR_SIZE;
+                        size_t t_h_j_offset = t_h_offset + base_j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
+
+                        // Load x elements at once
+                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
+                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
+                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
+
+                        // Compute kv = v * k
+                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
+
+                        // Compute temp = prev_state * g + kv
+                        GGML_F32X temp_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, g_vec);
+
+                        // Update dst: dst += temp * q
+                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, q_vec);
+                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
+
+                        // Update state
+                        GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], temp_vec);
+                    }
+
+                    // Handle remaining elements, this will not be used.
+                    for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = kv_val + prev_state_val * g_val;
+                        dst_data[t_h_j_offset] += temp_val * q_val;
+                        state_cur[h_2d_i_j_offset] = temp_val;
+                    }
+                }
+            }
+        }
+
+    #else
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float q_val = q[t_h_i_offset] * scale;
+                    float g_val = g[t_h_i_offset];
+
+                    for (int64_t j = 0; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = prev_state_val * g_val + kv_val;
+                        dst_data[t_h_j_offset] += temp_val * q_val;
+                        state_cur[h_2d_i_j_offset] = temp_val;
+                    }
+                }
+            }
+        }
+    #endif
+}
+
+
+static void ggml_compute_forward_gla(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_gla_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_map_unary
 
 static void ggml_compute_forward_map_unary_f32(
@@ -12346,22 +12793,22 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
 
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-    const struct ggml_tensor * opt0 = dst->src[2];
+    const struct ggml_tensor * grad  = dst->src[0]; // gradient of forward pass output
+    const struct ggml_tensor * src0f = dst->src[1]; // src0 of forward pass
+    const struct ggml_tensor * src1f = dst->src[2]; // src1 of forward pass
 
     GGML_ASSERT(ggml_is_contiguous(dst));
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(src1));
-    GGML_ASSERT(ggml_is_contiguous(opt0));
-    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_contiguous(src0f));
+    GGML_ASSERT(ggml_is_contiguous(src1f));
+    GGML_ASSERT(ggml_is_contiguous(grad));
+    GGML_ASSERT(ggml_are_same_shape(src0f, src1f) && ggml_are_same_shape(src0f, dst));
 
     const int64_t ith = params->ith;
     const int64_t nth = params->nth;
 
     // TODO: handle transposed/permuted matrices
-    const int64_t nc = src0->ne[0];
-    const int64_t nr = ggml_nrows(src0);
+    const int64_t nc = src0f->ne[0];
+    const int64_t nr = ggml_nrows(src0f);
 
     // rows per thread
     const int64_t dr = (nr + nth - 1)/nth;
@@ -12370,12 +12817,12 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
     const int64_t ir0 = dr*ith;
     const int64_t ir1 = MIN(ir0 + dr, nr);
 
-    const float d_by_nr = ((const float *) opt0->data)[0] / (float) nr;
+    const float d_by_nr = ((const float *) grad->data)[0] / (float) nr;
 
     for (int64_t i1 = ir0; i1 < ir1; i1++) {
-        float * ds0 = (float *)((char *) dst->data  + i1*dst->nb[1]);
-        float * s0  = (float *)((char *) src0->data + i1*src0->nb[1]);
-        float * s1  = (float *)((char *) src1->data + i1*src1->nb[1]);
+        float       * ds0 = (float       *)((char       *) dst->data   + i1*dst->nb[1]);
+        const float * s0  = (const float *)((const char *) src0f->data + i1*src0f->nb[1]);
+        const float * s1  = (const float *)((const char *) src1f->data + i1*src1f->nb[1]);
 
 #ifndef NDEBUG
         for (int64_t i = 0; i < nc; ++i) {
@@ -12388,11 +12835,11 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
         // soft_max
         float max = -INFINITY;
         ggml_vec_max_f32(nc, &max, s0);
-        ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
+        const ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
         assert(sum > 0.0);
         ggml_vec_scale_f32(nc, ds0, 1.0/sum);
 
-        // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
+        // grad(src0f) = (softmax(src0f) - src1f) * grad(cross_entropy_loss(src0f, src1f)) / nr
         ggml_vec_sub_f32(nc, ds0, ds0, s1);
         ggml_vec_scale_f32(nc, ds0, d_by_nr);
 
@@ -12689,7 +13136,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_SOFT_MAX_BACK:
             {
-                ggml_compute_forward_soft_max_back(params, tensor);
+                ggml_compute_forward_soft_max_ext_back(params, tensor);
             } break;
         case GGML_OP_ROPE:
             {
@@ -12806,6 +13253,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rwkv_wkv6(params, tensor);
             } break;
+        case GGML_OP_GATED_LINEAR_ATTN:
+            {
+                ggml_compute_forward_gla(params, tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 ggml_unary_op_f32_t fun;
@@ -13105,6 +13556,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
         case GGML_OP_RWKV_WKV6:
+        case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:
@@ -13513,14 +13965,19 @@ struct ggml_cplan ggml_graph_plan(
                         cur = 0;
                         const struct ggml_tensor * src0 = node->src[0];
                         const struct ggml_tensor * src1 = node->src[1];
+                        const struct ggml_tensor * ids = node->src[2];
                         const enum ggml_type vec_dot_type = type_traits_cpu[src0->type].vec_dot_type;
-                        if (src1->type != vec_dot_type) {
-                            cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
-                        }
                         const int n_as = src0->ne[2];
-                        cur += GGML_PAD(cur, sizeof(int64_t));       // align
-                        cur += n_as * sizeof(int64_t);               // matrix_row_counts
-                        cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows
+                        // src1
+                        if (src1->type != vec_dot_type) {
+                            cur += ggml_row_size(vec_dot_type, ggml_nelements(src1)) + sizeof(int64_t);
+                        }
+                        // matrix_row_counts
+                        cur += n_as * sizeof(int64_t) + sizeof(int64_t);
+                        // matrix_rows
+                        cur += n_as*ids->ne[0]*ids->ne[1]*sizeof(struct mmid_row_mapping) + sizeof(int64_t);
+                        // atomic_current_chunk
+                        cur += CACHE_LINE_SIZE*n_as + CACHE_LINE_SIZE;
                     } break;
                 case GGML_OP_OUT_PROD:
                     {
@@ -13530,6 +13987,7 @@ struct ggml_cplan ggml_graph_plan(
                     } break;
                 case GGML_OP_SOFT_MAX:
                 case GGML_OP_ROPE:
+                case GGML_OP_ROPE_BACK:
                     {
                         cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
                     } break;
@@ -13640,20 +14098,24 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
         /*.threadpool=*/ tp,
     };
 
-    for (int node_n = 0; node_n < cgraph->n_nodes && !tp->abort; node_n++) {
+    for (int node_n = 0; node_n < cgraph->n_nodes && atomic_load_explicit(&tp->abort, memory_order_relaxed) != node_n; node_n++) {
         struct ggml_tensor * node = cgraph->nodes[node_n];
 
         ggml_compute_forward(¶ms, node);
 
         if (state->ith == 0 && cplan->abort_callback &&
                 cplan->abort_callback(cplan->abort_callback_data)) {
-            tp->abort = true;
+            atomic_store_explicit(&tp->abort, node_n + 1, memory_order_relaxed);
             tp->ec    = GGML_STATUS_ABORTED;
         }
 
-        ggml_barrier(state->threadpool);
+        if (node_n + 1 < cgraph->n_nodes) {
+            ggml_barrier(state->threadpool);
+        }
     }
 
+    ggml_barrier(state->threadpool);
+
     return 0;
 }
 
@@ -13820,7 +14282,7 @@ static struct ggml_threadpool * ggml_threadpool_new_impl(
         threadpool->current_chunk    = 0;
         threadpool->stop             = false;
         threadpool->pause            = tpp->paused;
-        threadpool->abort            = false;
+        threadpool->abort            = -1;
         threadpool->workers          = NULL;
         threadpool->n_threads_max    = tpp->n_threads;
         threadpool->n_threads_cur    = tpp->n_threads;
@@ -13899,7 +14361,7 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
         threadpool->cgraph           = cgraph;
         threadpool->cplan            = cplan;
         threadpool->current_chunk    = 0;
-        threadpool->abort            = false;
+        threadpool->abort            = -1;
         threadpool->ec               = GGML_STATUS_SUCCESS;
     }
 
@@ -14098,6 +14560,14 @@ int ggml_cpu_has_vsx(void) {
 #endif
 }
 
+int ggml_cpu_has_vxe(void) {
+#if defined(__VXE__) || defined(__VXE2__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_neon(void) {
 #if defined(__ARM_ARCH) && defined(__ARM_NEON)
     return ggml_arm_arch_features.has_neon;
@@ -14138,6 +14608,14 @@ int ggml_cpu_get_sve_cnt(void) {
 #endif
 }
 
+int ggml_cpu_has_sme(void) {
+#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SME)
+    return ggml_arm_arch_features.has_sme;
+#else
+    return 0;
+#endif
+}
+
 void ggml_cpu_init(void) {
     // needed to initialize f16 tables
     {
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.cpp
index f11399cc..a84203f2 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.cpp
@@ -14,6 +14,10 @@
 #include "ggml-cpu-hbm.h"
 #endif
 
+#ifdef GGML_USE_CPU_KLEIDIAI
+#include "kleidiai/kleidiai.h"
+#endif
+
 #if defined(__APPLE__)
 #include 
 #include 
@@ -39,6 +43,12 @@ std::vector& ggml_backend_cpu_get_extra_buffers_type
         }
 #endif
 
+#ifdef GGML_USE_CPU_KLEIDIAI
+        if (ggml_backend_cpu_kleidiai_buffer_type()) {
+            bufts.push_back(ggml_backend_cpu_kleidiai_buffer_type());
+        }
+#endif
+
 #ifdef GGML_USE_CPU_AARCH64
         if (ggml_backend_cpu_aarch64_buffer_type()) {
             bufts.push_back(ggml_backend_cpu_aarch64_buffer_type());
@@ -284,14 +294,14 @@ struct ggml_backend_cpu_device_context {
                         &hKey) == ERROR_SUCCESS) {
             DWORD cpu_brand_size = 0;
             if (RegQueryValueExA(hKey,
-                                TEXT("ProcessorNameString"),
+                                "ProcessorNameString",
                                 NULL,
                                 NULL,
                                 NULL,
                                 &cpu_brand_size) == ERROR_SUCCESS) {
                 description.resize(cpu_brand_size);
                 if (RegQueryValueExA(hKey,
-                                    TEXT("ProcessorNameString"),
+                                    "ProcessorNameString",
                                     NULL,
                                     NULL,
                                     (LPBYTE)&description[0], // NOLINT
@@ -403,12 +413,21 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
                 op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
         case GGML_OP_MUL_MAT:
             return src1->type == GGML_TYPE_F32 || src1->type == ggml_get_type_traits_cpu(src0->type)->vec_dot_type;
-        case GGML_OP_ROPE_BACK:
-            return op->src[2] == NULL && (op->op_params[2] & 4) == 0;
+        case GGML_OP_SOFT_MAX_BACK: {
+            if (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type != GGML_TYPE_F32) {
+                return false;
+            }
+            float max_bias = 0.0f;
+
+            memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
+
+            return max_bias == 0.0f;
+        }
         case GGML_OP_IM2COL_BACK:
             return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
         case GGML_OP_OUT_PROD:
-            return (src0->type == GGML_TYPE_F32 || ggml_is_quantized(src0->type)) && src1->type == GGML_TYPE_F32;
+            return (src0->type == GGML_TYPE_F32 || (ggml_is_quantized(src0->type) && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3])) &&
+                src1->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
         default:
             return true;
     }
@@ -525,19 +544,22 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
         if (ggml_cpu_has_dotprod()) {
             features.push_back({ "DOTPROD", "1" });
         }
-        if (ggml_cpu_has_matmul_int8()) {
-            features.push_back({ "MATMUL_INT8", "1" });
-        }
         if (ggml_cpu_get_sve_cnt() > 0) {
             static std::string sve_cnt = std::to_string(ggml_cpu_get_sve_cnt());
             features.push_back({ "SVE_CNT", sve_cnt.c_str() });
         }
+        if (ggml_cpu_has_sme()) {
+            features.push_back({ "SME", "1" });
+        }
         if (ggml_cpu_has_riscv_v()) {
             features.push_back({ "RISCV_V", "1" });
         }
         if (ggml_cpu_has_vsx()) {
             features.push_back({ "VSX", "1" });
         }
+        if (ggml_cpu_has_vxe()) {
+            features.push_back({ "VXE", "1" });
+        }
         if (ggml_cpu_has_wasm_simd()) {
             features.push_back({ "WASM_SIMD", "1" });
         }
@@ -553,6 +575,9 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
     #ifdef GGML_USE_OPENMP
         features.push_back({ "OPENMP", "1" });
     #endif
+    #ifdef GGML_USE_CPU_KLEIDIAI
+        features.push_back({ "KLEIDIAI", "1" });
+    #endif
     #ifdef GGML_USE_CPU_AARCH64
         features.push_back({ "AARCH64_REPACK", "1" });
     #endif
diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm.cpp
index 3f260ce5..e0482c59 100644
--- a/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm.cpp
@@ -54,6 +54,7 @@
 #include "ggml-quants.h"
 
 #include 
+#include 
 
 #ifdef _MSC_VER
 #define NOINLINE __declspec(noinline)
@@ -1052,6 +1053,704 @@ class tinyBLAS_Q0_AVX {
       } \
    } \
 
+template 
+class tinyBLAS_Q0_PPC {
+  public:
+    tinyBLAS_Q0_PPC(int64_t k,
+                const TA *A, int64_t lda,
+                const TB *B, int64_t ldb,
+                TC *C, int64_t ldc,
+                int ith, int nth)
+        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+    }
+
+    void matmul(int64_t m, int64_t n) {
+        mnpack(0, m, 0, n);
+    }
+
+  private:
+
+    template
+    inline void save_res(int ii, int jj, int idx, vector float* fin_res) {
+       for (int I = 0; I < RM; I++) {
+          for (int J = 0; J < RN; J++) {
+             *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
+          }
+       }
+    }
+
+    template
+    inline void compute(acc_t* ACC, int c_idx, int s_idx, std::array& comparray, vector float* vs, vector float* fin_res) {
+       vector signed int vec_C[4];
+       vector float CA[4] = {0};
+       vector float res[4] = {0};
+       __builtin_mma_disassemble_acc(vec_C, ACC);
+       for (int i = 0; i < 4; i++) {
+          CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0));
+          res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
+          fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
+       }
+    }
+
+    template
+    void packNormal(const TA* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
+        int64_t i, j;
+        TA *aoffset = NULL;
+        VA *vecOffset = NULL;
+        TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
+        TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
+        __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
+        VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2]={0};
+        VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2]={0};
+        VB t1, t2, t3, t4, t5, t6, t7, t8;
+        vector unsigned char xor_vector;
+        uint8_t flip_vec = 0x80;
+        xor_vector = vec_splats(flip_vec);
+        vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
+        vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
+        vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
+        vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
+
+        aoffset = const_cast(a);
+        vecOffset = vec;
+        j = (rows >> 3);
+        if (j > 0) {
+            do {
+            aoffset1 = aoffset;
+            aoffset2 = aoffset1 + lda;
+            aoffset3 = aoffset2 + lda;
+            aoffset4 = aoffset3 + lda;
+            aoffset5 = aoffset4 + lda;
+            aoffset6 = aoffset5 + lda;
+            aoffset7 = aoffset6 + lda;
+            aoffset8 = aoffset7 + lda;
+            aoffset += 8 * lda;
+
+            i = (cols >> 3);
+            if (i > 0) {
+               do {
+                    C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
+                    C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
+                    C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
+                    C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs);
+                    C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5->qs);
+                    C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6->qs);
+                    C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7->qs);
+                    C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8->qs);
+
+                    __builtin_vsx_disassemble_pair(c1, &C1);
+                    __builtin_vsx_disassemble_pair(c2, &C2);
+                    __builtin_vsx_disassemble_pair(c3, &C3);
+                    __builtin_vsx_disassemble_pair(c4, &C4);
+                    __builtin_vsx_disassemble_pair(c5, &C5);
+                    __builtin_vsx_disassemble_pair(c6, &C6);
+                    __builtin_vsx_disassemble_pair(c7, &C7);
+                    __builtin_vsx_disassemble_pair(c8, &C8);
+
+                    t1 = vec_perm(c1[0], c2[0], swiz1);
+                    t2 = vec_perm(c1[0], c2[0], swiz2);
+                    t3 = vec_perm(c3[0], c4[0], swiz1);
+                    t4 = vec_perm(c3[0], c4[0], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    if (flip == true) {
+                       t5 = vec_xor(t5, xor_vector);
+                       t6 = vec_xor(t6, xor_vector);
+                       t7 = vec_xor(t7, xor_vector);
+                       t8 = vec_xor(t8, xor_vector);
+                    }
+                    vec_xst(t5, 0, vecOffset);
+                    vec_xst(t6, 0, vecOffset+16);
+                    vec_xst(t7, 0, vecOffset+32);
+                    vec_xst(t8, 0, vecOffset+48);
+
+                    t1 = vec_perm(c1[1], c2[1], swiz1);
+                    t2 = vec_perm(c1[1], c2[1], swiz2);
+                    t3 = vec_perm(c3[1], c4[1], swiz1);
+                    t4 = vec_perm(c3[1], c4[1], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    if (flip == true) {
+                       t5 = vec_xor(t5, xor_vector);
+                       t6 = vec_xor(t6, xor_vector);
+                       t7 = vec_xor(t7, xor_vector);
+                       t8 = vec_xor(t8, xor_vector);
+                    }
+                    vec_xst(t5, 0, vecOffset+64);
+                    vec_xst(t6, 0, vecOffset+80);
+                    vec_xst(t7, 0, vecOffset+96);
+                    vec_xst(t8, 0, vecOffset+112);
+
+                    t1 = vec_perm(c5[0], c6[0], swiz1);
+                    t2 = vec_perm(c5[0], c6[0], swiz2);
+                    t3 = vec_perm(c7[0], c8[0], swiz1);
+                    t4 = vec_perm(c7[0], c8[0], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    if (flip == true) {
+                       t5 = vec_xor(t5, xor_vector);
+                       t6 = vec_xor(t6, xor_vector);
+                       t7 = vec_xor(t7, xor_vector);
+                       t8 = vec_xor(t8, xor_vector);
+                    }
+                    vec_xst(t5, 0, vecOffset+128);
+                    vec_xst(t6, 0, vecOffset+144);
+                    vec_xst(t7, 0, vecOffset+160);
+                    vec_xst(t8, 0, vecOffset+176);
+
+                    t1 = vec_perm(c5[1], c6[1], swiz1);
+                    t2 = vec_perm(c5[1], c6[1], swiz2);
+                    t3 = vec_perm(c7[1], c8[1], swiz1);
+                    t4 = vec_perm(c7[1], c8[1], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    if (flip == true) {
+                       t5 = vec_xor(t5, xor_vector);
+                       t6 = vec_xor(t6, xor_vector);
+                       t7 = vec_xor(t7, xor_vector);
+                       t8 = vec_xor(t8, xor_vector);
+                    }
+                    vec_xst(t5, 0, vecOffset+192);
+                    vec_xst(t6, 0, vecOffset+208);
+                    vec_xst(t7, 0, vecOffset+224);
+                    vec_xst(t8, 0, vecOffset+240);
+
+                    aoffset1 += lda;
+                    aoffset2 += lda;
+                    aoffset3 += lda;
+                    aoffset4 += lda;
+                    aoffset5 += lda;
+                    aoffset6 += lda;
+                    aoffset7 += lda;
+                    aoffset8 += lda;
+                    vecOffset += 256;
+                    i--;
+               } while(i > 0);
+            }
+            j--;
+        } while(j > 0);
+    }
+
+    if (rows & 4) {
+            aoffset1 = aoffset;
+            aoffset2 = aoffset1 + lda;
+            aoffset3 = aoffset2 + lda;
+            aoffset4 = aoffset3 + lda;
+            aoffset += 4 * lda;
+
+        i = (cols >> 3);
+            if (i > 0) {
+               do {
+                    C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
+                    C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
+                    C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
+                    C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs);
+
+                    __builtin_vsx_disassemble_pair(c1, &C1);
+                    __builtin_vsx_disassemble_pair(c2, &C2);
+                    __builtin_vsx_disassemble_pair(c3, &C3);
+                    __builtin_vsx_disassemble_pair(c4, &C4);
+
+                    t1 = vec_perm(c1[0], c2[0], swiz1);
+                    t2 = vec_perm(c1[0], c2[0], swiz2);
+                    t3 = vec_perm(c3[0], c4[0], swiz1);
+                    t4 = vec_perm(c3[0], c4[0], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    if (flip == true) {
+                       t5 = vec_xor(t5, xor_vector);
+                       t6 = vec_xor(t6, xor_vector);
+                       t7 = vec_xor(t7, xor_vector);
+                       t8 = vec_xor(t8, xor_vector);
+                    }
+                    vec_xst(t5, 0, vecOffset);
+                    vec_xst(t6, 0, vecOffset+16);
+                    vec_xst(t7, 0, vecOffset+32);
+                    vec_xst(t8, 0, vecOffset+48);
+
+                    t1 = vec_perm(c1[1], c2[1], swiz1);
+                    t2 = vec_perm(c1[1], c2[1], swiz2);
+                    t3 = vec_perm(c3[1], c4[1], swiz1);
+                    t4 = vec_perm(c3[1], c4[1], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    if (flip == true) {
+                       t5 = vec_xor(t5, xor_vector);
+                       t6 = vec_xor(t6, xor_vector);
+                       t7 = vec_xor(t7, xor_vector);
+                       t8 = vec_xor(t8, xor_vector);
+                    }
+                    vec_xst(t5, 0, vecOffset+64);
+                    vec_xst(t6, 0, vecOffset+80);
+                    vec_xst(t7, 0, vecOffset+96);
+                    vec_xst(t8, 0, vecOffset+112);
+
+                    aoffset1 += lda;
+                    aoffset2 += lda;
+                    aoffset3 += lda;
+                    aoffset4 += lda;
+                    vecOffset += 128;
+                    i--;
+               } while(i > 0);
+            }
+        }
+        if (rows & 3) {
+            aoffset1 = aoffset;
+            aoffset2 = aoffset1 + lda;
+            aoffset3 = aoffset2 + lda;
+            i = (cols >> 3);
+        if (i > 0) {
+                do {
+                    switch(rows) {
+                        case 3: C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
+                                __builtin_vsx_disassemble_pair(c3, &C3);
+                        case 2: C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
+                                __builtin_vsx_disassemble_pair(c2, &C2);
+                        case 1: C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
+                                __builtin_vsx_disassemble_pair(c1, &C1);
+                                break;
+                    }
+                    t1 = vec_perm(c1[0], c2[0], swiz1);
+                    t2 = vec_perm(c1[0], c2[0], swiz2);
+                    t3 = vec_perm(c3[0], c4[0], swiz1);
+                    t4 = vec_perm(c3[0], c4[0], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    if (flip == true) {
+                       t5 = vec_xor(t5, xor_vector);
+                       t6 = vec_xor(t6, xor_vector);
+                       t7 = vec_xor(t7, xor_vector);
+                       t8 = vec_xor(t8, xor_vector);
+                    }
+                    vec_xst(t5, 0, vecOffset);
+                    vec_xst(t6, 0, vecOffset+16);
+                    vec_xst(t7, 0, vecOffset+32);
+                    vec_xst(t8, 0, vecOffset+48);
+
+                    t1 = vec_perm(c1[1], c2[1], swiz1);
+                    t2 = vec_perm(c1[1], c2[1], swiz2);
+                    t3 = vec_perm(c3[1], c4[1], swiz1);
+                    t4 = vec_perm(c3[1], c4[1], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    if (flip == true) {
+                       t5 = vec_xor(t5, xor_vector);
+                       t6 = vec_xor(t6, xor_vector);
+                       t7 = vec_xor(t7, xor_vector);
+                       t8 = vec_xor(t8, xor_vector);
+                    }
+                    vec_xst(t5, 0, vecOffset+64);
+                    vec_xst(t6, 0, vecOffset+80);
+                    vec_xst(t7, 0, vecOffset+96);
+                    vec_xst(t8, 0, vecOffset+112);
+
+                    aoffset1 += lda;
+                    aoffset2 += lda;
+                    aoffset3 += lda;
+                    vecOffset += 128;
+                    i--;
+               } while(i > 0);
+            }
+        }
+    }
+
+    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t mc, nc, mp, np;
+        int m_rem = MIN(m - m0, 8);
+        int n_rem = MIN(n - n0, 8);
+        // TO-DO: KERNEL_16x8 and KERNEL_8x16 are having some performance
+        // issues. After resolving them, below code will be enabled.
+        /*if (m_rem >= 16 && n_rem >= 8) {
+            mc = 16;
+            nc = 8;
+            gemm<16,8>(m0, m, n0, n);
+        } else if(m_rem >= 8 && n_rem >= 16) {
+            mc = 8;
+            nc = 16;
+            gemm<8,16>(m0, m, n0, n);
+        }*/
+        if (m_rem >= 8 && n_rem >= 8) {
+            mc = 8;
+            nc = 8;
+            gemm<8,8>(m0, m, n0, n);
+        } else if (m_rem >= 4 && n_rem >= 8) {
+            mc = 4;
+            nc = 8;
+            gemm<4,8>(m0, m, n0, n);
+        } else if (m_rem >= 8 && n_rem >= 4) {
+            mc = 8;
+            nc = 4;
+            gemm<8,4>(m0, m, n0, n);
+        } else if (m_rem >= 4 && n_rem >= 4) {
+            mc = 4;
+            nc = 4;
+            gemm_small<4, 4>(m0, m, n0, n);
+        } else if ((m_rem < 4) && (n_rem > 4)) {
+            nc = 4;
+            switch(m_rem) {
+                case 1:
+                    mc = 1;
+                    gemm_small<1, 4>(m0, m, n0, n);
+                    break;
+                case 2:
+                    mc = 2;
+                    gemm_small<2, 4>(m0, m, n0, n);
+                    break;
+                case 3:
+                    mc = 3;
+                    gemm_small<3, 4>(m0, m, n0, n);
+                    break;
+                default:
+                    return;
+            }
+        } else if ((m_rem > 4) && (n_rem < 4)) {
+            mc = 4;
+            switch(n_rem) {
+                case 1:
+                    nc = 1;
+                    gemm_small<4, 1>(m0, m, n0, n);
+                    break;
+                case 2:
+                    nc = 2;
+                    gemm_small<4, 2>(m0, m, n0, n);
+                    break;
+                case 3:
+                    nc = 3;
+                    gemm_small<4, 3>(m0, m, n0, n);
+                    break;
+                default:
+                    return;
+            }
+        } else {
+            switch((m_rem << 4) | n_rem) {
+                case 0x43:
+                    mc = 4;
+                    nc = 3;
+                    gemm_small<4, 3>(m0, m, n0, n);
+                    break;
+                case 0x42:
+                    mc = 4;
+                    nc = 2;
+                    gemm_small<4, 2>(m0, m, n0, n);
+                    break;
+                case 0x41:
+                    mc = 4;
+                    nc = 1;
+                    gemm_small<4, 1>(m0, m, n0, n);
+                    break;
+                case 0x34:
+                    mc = 3;
+                    nc = 4;
+                    gemm_small<3, 4>(m0, m, n0, n);
+                    break;
+                case 0x33:
+                    mc = 3;
+                    nc = 3;
+                    gemm_small<3, 3>(m0, m, n0, n);
+                    break;
+                case 0x32:
+                    mc = 3;
+                    nc = 2;
+                    gemm_small<3, 2>(m0, m, n0, n);
+                    break;
+                case 0x31:
+                    mc = 3;
+                    nc = 1;
+                    gemm_small<3, 1>(m0, m, n0, n);
+                    break;
+                case 0x24:
+                    mc = 2;
+                    nc = 4;
+                    gemm_small<2, 4>(m0, m, n0, n);
+                    break;
+                case 0x23:
+                    mc = 2;
+                    nc = 3;
+                    gemm_small<2, 3>(m0, m, n0, n);
+                    break;
+                case 0x22:
+                    mc = 2;
+                    nc = 2;
+                    gemm_small<2, 2>(m0, m, n0, n);
+                    break;
+                case 0x21:
+                    mc = 2;
+                    nc = 1;
+                    gemm_small<2, 1>(m0, m, n0, n);
+                    break;
+                case 0x14:
+                    mc = 1;
+                    nc = 4;
+                    gemm_small<1, 4>(m0, m, n0, n);
+                    break;
+                case 0x13:
+                    mc = 1;
+                    nc = 3;
+                    gemm_small<1, 3>(m0, m, n0, n);
+                    break;
+                case 0x12:
+                    mc = 1;
+                    nc = 2;
+                    gemm_small<1, 2>(m0, m, n0, n);
+                    break;
+                case 0x11:
+                    mc = 1;
+                    nc = 1;
+                    gemm_small<1, 1>(m0, m, n0, n);
+                    break;
+                default:
+                    return;
+            }
+        }
+        mp = m0 + (m - m0) / mc * mc;
+        np = n0 + (n - n0) / nc * nc;
+        mnpack(mp, m, n0, np);
+        mnpack(m0, m, np, n);
+    }
+
+    void KERNEL_4x8(int64_t ii, int64_t jj) {
+        vec_t vec_A[8], vec_B[16] = {0};
+        acc_t acc_0, acc_1;
+        std::array comparray;
+        vector float fin_res[8] = {0};
+        vector float vs[8] = {0};
+        for (int l = 0; l < k; l++) {
+            __builtin_mma_xxsetaccz(&acc_0);
+            __builtin_mma_xxsetaccz(&acc_1);
+            packNormal((A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
+            packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
+            for(int x = 0; x < 8; x++) {
+                __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x], vec_B[x+8]);
+            }
+            for (int I = 0; I<4; I++) {
+                for (int J = 0; J<4; J++) {
+                    *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
+                    *((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
+                }
+            }
+            auto aoffset = A+(ii*lda)+l;
+            for (int i = 0; i < 4; i++) {
+                comparray[i] = 0;
+                int ca = 0;
+                const int8_t *at = aoffset->qs;
+                for (int j = 0; j < 32; j++)
+                    ca += (int)*at++;
+                comparray[i] = ca;
+                aoffset += lda;
+            }
+            compute<4>(&acc_0, 0, 0, comparray, vs, fin_res);
+            compute<4>(&acc_1, 0, 4, comparray, vs, fin_res);
+        }
+        save_res<4, 4>(ii, jj, 0, fin_res);
+        save_res<4, 4>(ii, jj+4, 4, fin_res);
+    }
+
+    void KERNEL_8x4(int64_t ii, int64_t jj) {
+        vec_t vec_A[16], vec_B[8] = {0};
+        acc_t acc_0, acc_1;
+        std::array comparray;
+        vector float fin_res[8] = {0};
+        vector float vs[8] = {0};
+        for (int l = 0; l < k; l++) {
+            __builtin_mma_xxsetaccz(&acc_0);
+            __builtin_mma_xxsetaccz(&acc_1);
+            packNormal((A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+            packNormal((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
+            for(int x = 0; x < 8; x++) {
+                __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
+            }
+            for (int I = 0; I<8; I++) {
+                for (int J = 0; J<4; J++) {
+                    *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
+                }
+            }
+            auto aoffset = A+(ii*lda)+l;
+            for (int i = 0; i < 8; i++) {
+                comparray[i] = 0;
+                int ca = 0;
+                const int8_t *at = aoffset->qs;
+                for (int j = 0; j < 32; j++)
+                    ca += (int)*at++;
+                comparray[i] = ca;
+                aoffset += lda;
+            }
+            compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
+            compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
+        }
+        save_res<4, 4>(ii, jj, 0, fin_res);
+        save_res<4, 4>(ii+4, jj, 4, fin_res);
+    }
+
+    void KERNEL_8x8(int64_t ii, int64_t jj) {
+        vec_t vec_A[16], vec_B[16] = {0};
+        acc_t acc_0, acc_1, acc_2, acc_3;
+        std::array comparray;
+        vector float fin_res[16] = {0};
+        vector float vs[16] = {0};
+        for (int l = 0; l < k; l++) {
+            __builtin_mma_xxsetaccz(&acc_0);
+            __builtin_mma_xxsetaccz(&acc_1);
+            __builtin_mma_xxsetaccz(&acc_2);
+            __builtin_mma_xxsetaccz(&acc_3);
+            packNormal((A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+            packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
+            for(int x = 0; x < 8; x++) {
+                __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
+                __builtin_mma_xvi8ger4pp(&acc_2, vec_A[x], vec_B[x+8]);
+                __builtin_mma_xvi8ger4pp(&acc_3, vec_A[x+8], vec_B[x+8]);
+            }
+            for (int I = 0; I<8; I++) {
+                for (int J = 0; J<4; J++) {
+                    *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
+                    *((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
+                }
+            }
+            auto aoffset = A+(ii*lda)+l;
+            for (int i = 0; i < 8; i++) {
+                comparray[i] = 0;
+                int ca = 0;
+                const int8_t *at = aoffset->qs;
+                for (int j = 0; j < 32; j++)
+                    ca += (int)*at++;
+                comparray[i] = ca;
+                aoffset += lda;
+            }
+            compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
+            compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
+            compute<8>(&acc_2, 0, 8, comparray, vs, fin_res);
+            compute<8>(&acc_3, 4, 12, comparray, vs, fin_res);
+        }
+        save_res<4, 4>(ii, jj, 0, fin_res);
+        save_res<4, 4>(ii+4, jj, 4, fin_res);
+        save_res<4, 4>(ii, jj+4, 8, fin_res);
+        save_res<4, 4>(ii+4, jj+4, 12, fin_res);
+    }
+
+    template
+    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        vec_t vec_A[8], vec_B[8] = {0};
+        vector signed int vec_C[4];
+        acc_t acc_0;
+
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            std::array comparray;
+            vector float res[4] = {0};
+            vector float fin_res[4] = {0};
+            vector float vs[4] = {0};
+            vector float CA[4] = {0};
+            __builtin_prefetch((A+(ii*lda)+0)->qs, 0, 1); // prefetch first value
+            __builtin_prefetch((B+(jj*ldb)+0)->qs, 0, 1); // prefetch first value
+            for (int l = 0; l < k; l++) {
+                __builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
+                __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
+                __builtin_mma_xxsetaccz(&acc_0);
+                packNormal((A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
+                packNormal((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
+                for(int x = 0; x < 8; x+=4) {
+                    __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
+                    __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+1], vec_B[x+1]);
+                    __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+2], vec_B[x+2]);
+                    __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+3], vec_B[x+3]);
+                }
+                for (int I = 0; Id) * unhalf((B+((jj+J)*ldb)+l)->d));
+                    }
+                }
+                __builtin_mma_disassemble_acc(vec_C, &acc_0);
+                auto aoffset = A+(ii*lda)+l;
+                for (int i = 0; i < RM; i++) {
+                    comparray[i] = 0;
+                    int ca = 0;
+                    const int8_t *at = aoffset->qs;
+                    for (int j = 0; j < 32; j++)
+                        ca += (int)*at++;
+                    comparray[i] = ca;
+                    aoffset += lda;
+                }
+
+                for (int i = 0; i < RM; i++) {
+                    CA[i] = vec_splats((float)(((double)comparray[i]) * -128.0));
+                    res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
+                    fin_res[i] = vec_madd(res[i], vs[i], fin_res[i]);
+                }
+            }
+            save_res(ii, jj, 0, fin_res);
+        }
+    }
+
+    template
+    inline void kernel(int64_t ii, int64_t jj) {
+       if constexpr(RM == 4 && RN == 8) {
+          KERNEL_4x8(ii,jj);
+       } else if constexpr(RM == 8 && RN == 4) {
+          KERNEL_8x4(ii,jj);
+       } else if constexpr(RM == 8 && RN == 8) {
+          KERNEL_8x8(ii,jj);
+       } else {
+          static_assert(false, "RN/RM values not supported");
+       }
+    }
+
+    template 
+    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            kernel(ii, jj);
+        }
+    }
+
+    const TA *const A;
+    const TB *const B;
+    TC *C;
+    TA *At;
+    TB *Bt;
+    const int64_t k;
+    const int64_t lda;
+    const int64_t ldb;
+    const int64_t ldc;
+    const int ith;
+    const int nth;
+};
+
 template 
 class tinyBLAS_PPC {
   public:
@@ -1071,13 +1770,17 @@ class tinyBLAS_PPC {
 
     void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
 
-    void READ_BLOCK(const float* a, int64_t lda, int rows, int cols, float* vec) {
+    template
+    void packTranspose(const TA* a, int64_t lda, int rows, int cols, TA* vec) {
         int64_t i, j;
-        float *aoffset = NULL, *boffset = NULL;
-        float *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
-        float *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
-
-        aoffset = const_cast(a);
+        TA *aoffset = NULL, *boffset = NULL;
+        TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
+        TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
+        __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
+        VA c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
+        VA c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
+        VA t1, t2, t3, t4, t5, t6, t7, t8;
+        aoffset = const_cast(a);
         boffset = vec;
         j = (rows >> 3);
         if (j > 0) {
@@ -1093,9 +1796,6 @@ class tinyBLAS_PPC {
                 aoffset += 8 * lda;
                 i = (cols >> 3);
                 if (i > 0) {
-                    __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
-                    vector float c1[2], c2[2], c3[2], c4[2], c5[2], c6[2], c7[2], c8[2];
-                    vector float t1, t2, t3, t4, t5, t6, t7, t8;
                     do {
                         C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
                         C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
@@ -1175,21 +1875,19 @@ class tinyBLAS_PPC {
                     } while(i > 0);
                 }
                 if (cols & 4) {
-                    vector float c1, c2, c3, c4, c5, c6, c7, c8;
-                    vector float t1, t2, t3, t4, t5, t6, t7, t8;
-                    c1 = vec_xl(0, aoffset1);
-                    c2 = vec_xl(0, aoffset2);
-                    c3 = vec_xl(0, aoffset3);
-                    c4 = vec_xl(0, aoffset4);
-                    c5 = vec_xl(0, aoffset5);
-                    c6 = vec_xl(0, aoffset6);
-                    c7 = vec_xl(0, aoffset7);
-                    c8 = vec_xl(0, aoffset8);
+                    c1[0] = vec_xl(0, aoffset1);
+                    c2[0] = vec_xl(0, aoffset2);
+                    c3[0] = vec_xl(0, aoffset3);
+                    c4[0] = vec_xl(0, aoffset4);
+                    c5[0] = vec_xl(0, aoffset5);
+                    c6[0] = vec_xl(0, aoffset6);
+                    c7[0] = vec_xl(0, aoffset7);
+                    c8[0] = vec_xl(0, aoffset8);
 
-                    t1 = vec_mergeh(c1, c2);
-                    t2 = vec_mergeh(c3, c4);
-                    t3 = vec_mergeh(c5, c6);
-                    t4 = vec_mergeh(c7, c8);
+                    t1 = vec_mergeh(c1[0], c2[0]);
+                    t2 = vec_mergeh(c3[0], c4[0]);
+                    t3 = vec_mergeh(c5[0], c6[0]);
+                    t4 = vec_mergeh(c7[0], c8[0]);
                     t5 = vec_xxpermdi(t1, t2, 0);
                     t6 = vec_xxpermdi(t3, t4, 0);
                     t7 = vec_xxpermdi(t1, t2, 3);
@@ -1199,10 +1897,10 @@ class tinyBLAS_PPC {
                     vec_xst(t7, 0, boffset+8);
                     vec_xst(t8, 0, boffset+12);
 
-                    t1 = vec_mergel(c1, c2);
-                    t2 = vec_mergel(c3, c4);
-                    t3 = vec_mergel(c5, c6);
-                    t4 = vec_mergel(c7, c8);
+                    t1 = vec_mergel(c1[0], c2[0]);
+                    t2 = vec_mergel(c3[0], c4[0]);
+                    t3 = vec_mergel(c5[0], c6[0]);
+                    t4 = vec_mergel(c7[0], c8[0]);
                     t5 = vec_xxpermdi(t1, t2, 0);
                     t6 = vec_xxpermdi(t3, t4, 0);
                     t7 = vec_xxpermdi(t1, t2, 3);
@@ -1224,9 +1922,6 @@ class tinyBLAS_PPC {
             aoffset += 4 * lda;
             i = (cols >> 3);
             if (i > 0) {
-                __vector_pair C1, C2, C3, C4;
-                vector float c1[2], c2[2], c3[2], c4[2];
-                vector float t1, t2, t3, t4, t5, t6, t7, t8;
                 do {
                     C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
                     C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
@@ -1273,22 +1968,20 @@ class tinyBLAS_PPC {
             }
 
             if (cols & 4) {
-                vector float c1, c2, c3, c4;
-                vector float t1, t2, t3, t4;
-                c1 = vec_xl(0, aoffset1);
-                c2 = vec_xl(0, aoffset2);
-                c3 = vec_xl(0, aoffset3);
-                c4 = vec_xl(0, aoffset4);
+                c1[0] = vec_xl(0, aoffset1);
+                c2[0] = vec_xl(0, aoffset2);
+                c3[0] = vec_xl(0, aoffset3);
+                c4[0] = vec_xl(0, aoffset4);
 
-                t1 = vec_mergeh(c1, c2);
-                t2 = vec_mergeh(c3, c4);
+                t1 = vec_mergeh(c1[0], c2[0]);
+                t2 = vec_mergeh(c3[0], c4[0]);
                 t3 = vec_xxpermdi(t1, t2, 0);
                 t4 = vec_xxpermdi(t1, t2, 3);
                 vec_xst(t3, 0, boffset);
                 vec_xst(t4, 0, boffset+4);
 
-                t1 = vec_mergel(c1, c2);
-                t2 = vec_mergel(c3, c4);
+                t1 = vec_mergel(c1[0], c2[0]);
+                t2 = vec_mergel(c3[0], c4[0]);
                 t3 = vec_xxpermdi(t1, t2, 0);
                 t4 = vec_xxpermdi(t1, t2, 3);
                 vec_xst(t3, 0, boffset+8);
@@ -1300,21 +1993,19 @@ class tinyBLAS_PPC {
             aoffset2 = aoffset1 + lda;
             aoffset3 = aoffset2 + lda;
             if (cols & 4) {
-                vector float c1, c2, c3, c4 = {0};
-                vector float t1, t2, t3, t4;
-                c1 = vec_xl(0, aoffset1);
-                c2 = vec_xl(0, aoffset2);
-                c3 = vec_xl(0, aoffset3);
+                c1[0] = vec_xl(0, aoffset1);
+                c2[0] = vec_xl(0, aoffset2);
+                c3[0] = vec_xl(0, aoffset3);
 
-                t1 = vec_mergeh(c1, c2);
-                t2 = vec_mergeh(c3, c4);
+                t1 = vec_mergeh(c1[0], c2[0]);
+                t2 = vec_mergeh(c3[0], c4[0]);
                 t3 = vec_xxpermdi(t1, t2, 0);
                 t4 = vec_xxpermdi(t1, t2, 3);
                 vec_xst(t3, 0, boffset);
                 vec_xst(t4, 0, boffset+4);
 
-                t1 = vec_mergel(c1, c2);
-                t2 = vec_mergel(c3, c4);
+                t1 = vec_mergel(c1[0], c2[0]);
+                t2 = vec_mergel(c3[0], c4[0]);
                 t3 = vec_xxpermdi(t1, t2, 0);
                 t4 = vec_xxpermdi(t1, t2, 3);
                 vec_xst(t3, 0, boffset+8);
@@ -1322,14 +2013,13 @@ class tinyBLAS_PPC {
             }
         }
     }
-
     void KERNEL_4x4(int64_t ii, int64_t jj) {
         vec_t vec_A[4], vec_B[4], vec_C[4];
         acc_t acc_0;
         __builtin_mma_xxsetaccz(&acc_0);
         for (int l = 0; l < k; l+=4) {
-            READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
-            READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+            packTranspose(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A);
+            packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
             __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
             __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
             __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
@@ -1344,8 +2034,8 @@ class tinyBLAS_PPC {
         __builtin_mma_xxsetaccz(&acc_0);
         __builtin_mma_xxsetaccz(&acc_1);
         for (int64_t l = 0; l < k; l+=4) {
-            READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
-            READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B);
+            packTranspose(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A);
+            packTranspose(B+(jj*ldb)+l, ldb, 8, 4, (TA*)vec_B);
             __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
             __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
             __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
@@ -1365,8 +2055,8 @@ class tinyBLAS_PPC {
         __builtin_mma_xxsetaccz(&acc_0);
         __builtin_mma_xxsetaccz(&acc_1);
         for (int64_t l = 0; l < k; l+=4) {
-            READ_BLOCK(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A);
-            READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+            packTranspose(A+(ii*lda)+l, lda, 8, 4, (TA*)vec_A);
+            packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
             __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
             __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
             __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
@@ -1388,8 +2078,8 @@ class tinyBLAS_PPC {
         __builtin_mma_xxsetaccz(&acc_2);
         __builtin_mma_xxsetaccz(&acc_3);
         for (int l = 0; l < k; l+=8) {
-            READ_BLOCK(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
-            READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
+            packTranspose(A+(ii*lda)+l, lda, 8, 8, (TA*)vec_A);
+            packTranspose(B+(jj*ldb)+l, ldb, 8, 8, (TA*)vec_B);
             for(int x = 0; x < 16; x+=2) {
                 __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
                 __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
@@ -1572,15 +2262,15 @@ class tinyBLAS_PPC {
             vec_t vec_A[4], vec_B[4];
             for (int l=0; l= 4 && RM == 1) {
-                    float* a = const_cast(A+(ii)*lda+l);
-                    READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+                    TA* a = const_cast(A+(ii)*lda+l);
+                    packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
                     vec_A[0] = (vec_t)vec_xl(0,a);
-                    vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1));
-                    vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2));
-                    vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3));
+                    vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1));
+                    vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2));
+                    vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3));
                 } else {
-                    READ_BLOCK(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
-                    READ_BLOCK(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
+                    packTranspose(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
+                    packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
                 }
                 __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
                 __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
@@ -1590,7 +2280,7 @@ class tinyBLAS_PPC {
             __builtin_mma_disassemble_acc(vec_C, &acc_0);
             for (int I = 0; I < RM; I++) {
                 for (int J = 0; J < RN; J++) {
-                    *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J);
+                    *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
                 }
             }
        }
@@ -1813,6 +2503,20 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
             params->ith, params->nth};
         tb.matmul(m, n);
         return true;
+
+#elif defined(__MMA__)
+        if (n < 8 && n != 4)
+           return false;
+        if (m < 8 && m != 4)
+           return false;
+        tinyBLAS_Q0_PPC tb{
+            k, (const block_q8_0 *)A, lda,
+            (const block_q8_0 *)B, ldb,
+            (float *)C, ldc,
+            params->ith, params->nth};
+        tb.matmul(m, n);
+        return true;
+
 #else
         return false;
 #endif
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-cuda/CMakeLists.txt
index 14761650..96bd5a0b 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/CMakeLists.txt
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/CMakeLists.txt
@@ -7,7 +7,7 @@ if (CUDAToolkit_FOUND)
 
     if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
         # native == GPUs available at build time
-        # 52     == Maxwell, lowest CUDA 12 standard
+        # 50     == Maxwell, lowest CUDA 12 standard
         # 60     == P100, FP16 CUDA intrinsics
         # 61     == Pascal, __dp4a instruction (per-byte integer dot product)
         # 70     == V100, FP16 tensor cores
@@ -15,9 +15,9 @@ if (CUDAToolkit_FOUND)
         if (GGML_NATIVE AND CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.6" AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.24")
             set(CMAKE_CUDA_ARCHITECTURES "native")
         elseif(GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
-            set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75")
+            set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75;80")
         else()
-            set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75")
+            set(CMAKE_CUDA_ARCHITECTURES "50;61;70;75;80")
         endif()
     endif()
     message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
@@ -28,7 +28,7 @@ if (CUDAToolkit_FOUND)
     list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
 
     file(GLOB   GGML_SOURCES_CUDA "*.cu")
-    file(GLOB   SRCS "template-instances/fattn-wmma*.cu")
+    file(GLOB   SRCS "template-instances/fattn-mma*.cu")
     list(APPEND GGML_SOURCES_CUDA ${SRCS})
     file(GLOB   SRCS "template-instances/mmq*.cu")
     list(APPEND GGML_SOURCES_CUDA ${SRCS})
@@ -69,6 +69,10 @@ if (CUDAToolkit_FOUND)
         add_compile_definitions(GGML_CUDA_NO_VMM)
     endif()
 
+    if (NOT GGML_CUDA_FA)
+        add_compile_definitions(GGML_CUDA_NO_FA)
+    endif()
+
     if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
         add_compile_definitions(GGML_CUDA_F16)
     endif()
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/binbcast.cu b/ml/backend/ggml/ggml/src/ggml-cuda/binbcast.cu
index c7b6be4e..ce4b9cfb 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/binbcast.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/binbcast.cu
@@ -93,26 +93,31 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
 
 template 
 static __global__ void k_repeat_back(
-    const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
-    const int64_t ne0, const int64_t ne1, const int64_t ne2) {
+    const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+    const size_t s00, const size_t s01, const size_t s02, const size_t s03,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3) {
 
-    const int64_t tid0 = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
-    const int64_t tid1 = (int64_t) blockIdx.y*blockDim.y + threadIdx.y;
-    const int64_t tid2 = (int64_t) blockIdx.z*blockDim.z + threadIdx.z;
+    const int64_t tid0  = int64_t(blockIdx.x)*blockDim.x + threadIdx.x;
+    const int64_t tid1  = int64_t(blockIdx.y)*blockDim.y + threadIdx.y;
+    const int64_t tid23 = int64_t(blockIdx.z)*blockDim.z + threadIdx.z;
+    const int64_t tid2  = tid23 % ne2;
+    const int64_t tid3  = tid23 / ne2;
 
     if (tid0 >= ne0) {
         return;
     }
 
     T sum = 0;
-    for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
-        for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
-            for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
-                sum += src[i2*ne01*ne00 + i1*ne00 + i0];
+    for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) {
+        for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
+            for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
+                for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
+                    sum += src[i3*s03 + i2*s02 + i1*s01 + i0*s00];
+                }
             }
         }
     }
-    dst[tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
+    dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
 }
 
 template
@@ -274,12 +279,14 @@ struct bin_bcast_cuda {
 
 template 
 static void repeat_back_cuda(
-    const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
-    const int64_t ne0, const int64_t ne1, const int64_t ne2, cudaStream_t stream) {
+    const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+    const size_t s00, const size_t s01, const size_t s02, const size_t s03,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
 
     const dim3 block_dims(WARP_SIZE, 1, 1);
-    const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2);
-    k_repeat_back<<>>(src, dst, ne00, ne01, ne02, ne0, ne1, ne2);
+    const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2*ne3);
+    k_repeat_back<<>>
+        (src, dst, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3);
 }
 
 template
@@ -326,27 +333,26 @@ void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const ggml_tensor * src0 = dst->src[0];
 
     GGML_ASSERT(src0->type == dst->type);
-    GGML_ASSERT(ggml_is_contiguous(src0));
     GGML_ASSERT(ggml_is_contiguous(dst));
     GGML_ASSERT(ggml_can_repeat(dst, src0));
 
     cudaStream_t stream = ctx.stream();
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    GGML_ASSERT(src0->ne[3] == 1);
+    GGML_TENSOR_UNARY_OP_LOCALS;
 
-    const int64_t ne0 = dst->ne[0];
-    const int64_t ne1 = dst->ne[1];
-    const int64_t ne2 = dst->ne[2];
-    GGML_ASSERT(dst->ne[3] == 1);
+    GGML_ASSERT(ne2*ne3 <= (1 << 15));
+
+    const size_t ts = ggml_type_size(src0->type);
+    const size_t s00 = nb00 / ts;
+    const size_t s01 = nb01 / ts;
+    const size_t s02 = nb02 / ts;
+    const size_t s03 = nb03 / ts;
 
     switch (dst->type) {
         case GGML_TYPE_F32: {
             const float * src0_d = (const float *) src0->data;
             float       * dst_d  = (float       *) dst->data;
-            repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne0, ne1, ne2, stream);
+            repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3, stream);
         } break;
         default: {
             GGML_ASSERT(false);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh
index 2c0a5622..adf0d3ec 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh
@@ -41,29 +41,78 @@
 #define CUDART_HMAX   11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
 #define CUDART_HMASK  12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons
 
-#define GGML_CUDA_CC_PASCAL     600
-#define GGML_CUDA_CC_DP4A       610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
-#define GGML_CUDA_CC_VOLTA      700
-#define GGML_CUDA_CC_TURING     750
-#define GGML_CUDA_CC_AMPERE     800
-#define GGML_CUDA_CC_OFFSET_AMD 1000000
+#define GGML_CUDA_CC_PASCAL       600
+#define GGML_CUDA_CC_DP4A         610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
+#define GGML_CUDA_CC_VOLTA        700
+#define GGML_CUDA_CC_TURING       750
+#define GGML_CUDA_CC_AMPERE       800
+#define GGML_CUDA_CC_ADA_LOVELACE 890
+#define GGML_CUDA_CC_OFFSET_AMD   0x1000000
 
 // GCN/CNDA, wave size is 64
-#define GGML_CUDA_CC_GCN4       (GGML_CUDA_CC_OFFSET_AMD + 803)  // Tonga, Fiji, Polaris, minimum for fast fp16
-#define GGML_CUDA_CC_VEGA       (GGML_CUDA_CC_OFFSET_AMD + 900)  // Vega56/64, minimum for fp16 dual issue
-#define GGML_CUDA_CC_VEGA20     (GGML_CUDA_CC_OFFSET_AMD + 906)  // MI50/Radeon VII, minimum for dp4a
-#define GGML_CUDA_CC_CDNA       (GGML_CUDA_CC_OFFSET_AMD + 908)  // MI100, minimum for MFMA, acc registers
-#define GGML_CUDA_CC_CDNA2      (GGML_CUDA_CC_OFFSET_AMD + 910)  // MI210, minimum acc register renameing
-#define GGML_CUDA_CC_CDNA3      (GGML_CUDA_CC_OFFSET_AMD + 942)  // MI300
+#define GGML_CUDA_CC_GCN4       (GGML_CUDA_CC_OFFSET_AMD + 0x803)  // Tonga, Fiji, Polaris, minimum for fast fp16
+#define GGML_CUDA_CC_VEGA       (GGML_CUDA_CC_OFFSET_AMD + 0x900)  // Vega56/64, minimum for fp16 dual issue
+#define GGML_CUDA_CC_VEGA20     (GGML_CUDA_CC_OFFSET_AMD + 0x906)  // MI50/Radeon VII, minimum for dp4a
+#define GGML_CUDA_CC_CDNA       (GGML_CUDA_CC_OFFSET_AMD + 0x908)  // MI100, minimum for MFMA, acc registers
+#define GGML_CUDA_CC_CDNA2      (GGML_CUDA_CC_OFFSET_AMD + 0x910)  // MI210, minimum acc register renameing
+#define GGML_CUDA_CC_CDNA3      (GGML_CUDA_CC_OFFSET_AMD + 0x942)  // MI300
 
 // RNDA removes MFMA, dp4a, xnack, acc registers, wave size is 32
-#define GGML_CUDA_CC_RDNA1      (GGML_CUDA_CC_OFFSET_AMD + 1010) // RX 5000
-#define GGML_CUDA_CC_RDNA2      (GGML_CUDA_CC_OFFSET_AMD + 1030) // RX 6000, minimum for dp4a
-#define GGML_CUDA_CC_RDNA3      (GGML_CUDA_CC_OFFSET_AMD + 1100) // RX 7000, minimum for WMMA
+#define GGML_CUDA_CC_RDNA1      (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
+#define GGML_CUDA_CC_RDNA2      (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
+#define GGML_CUDA_CC_RDNA3      (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
+
+#define GGML_CUDA_CC_IS_RDNA(cc)  (cc >= GGML_CUDA_CC_RDNA1)
+#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
+#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
+#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3)
+#define GGML_CUDA_CC_IS_GCN(cc)   (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
+#define GGML_CUDA_CC_IS_CDNA(cc)  (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
 
 #define GGML_CUDA_CC_QY1        210
 #define GGML_CUDA_CC_QY2        220
 
+#ifdef __CUDA_ARCH_LIST__
+constexpr bool ggml_cuda_has_arch_impl(int) {
+    return false;
+}
+
+template
+constexpr bool ggml_cuda_has_arch_impl(const int arch, const int first, Archs... rest) {
+    return arch == first || ggml_cuda_has_arch_impl(arch, rest...);
+}
+
+constexpr bool ggml_cuda_has_arch(const int arch) {
+    return ggml_cuda_has_arch_impl(arch, __CUDA_ARCH_LIST__);
+}
+
+constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur) {
+    if (cur == 0) {
+        GGML_ABORT("ggml was not compiled with any CUDA arch <= %d", arch);
+    }
+    return cur;
+}
+
+template
+constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur, const int first, Archs... rest) {
+    if (first <= arch && first > cur) {
+        return ggml_cuda_highest_compiled_arch_impl(arch, first, rest...);
+    } else {
+        return ggml_cuda_highest_compiled_arch_impl(arch, cur, rest...);
+    }
+}
+
+constexpr int ggml_cuda_highest_compiled_arch(const int arch) {
+    return ggml_cuda_highest_compiled_arch_impl(arch, 0, __CUDA_ARCH_LIST__);
+}
+#else
+static int ggml_cuda_highest_compiled_arch(const int arch) {
+    return arch;
+}
+#endif // __CUDA_ARCH_LIST__
+
+// ---------------------------------------------------------------------------------------------------------
+
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
 #if defined(_MSC_VER)
@@ -117,11 +166,11 @@ static const char * cu_get_error_str(CUresult err) {
 #define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
 #endif
 
-#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
+#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
 #define GGML_CUDA_ASSUME(x) __builtin_assume(x)
 #else
 #define GGML_CUDA_ASSUME(x)
-#endif // CUDART_VERSION >= 11100
+#endif // CUDART_VERSION >= 11010
 
 #ifdef GGML_CUDA_F16
 typedef half dfloat; // dequantize float
@@ -131,6 +180,10 @@ typedef float dfloat; // dequantize float
 typedef float2 dfloat2;
 #endif // GGML_CUDA_F16
 
+#if (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))
+#define GGML_USE_VMM
+#endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))
+
 #if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
 #define FP16_AVAILABLE
 #endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
@@ -144,23 +197,55 @@ typedef float2 dfloat2;
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
-#define INT8_MMA_AVAILABLE
+#define NEW_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 
-#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
-#define FLASH_ATTN_AVAILABLE
-#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#define CP_ASYNC_AVAILABLE
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 
-static constexpr bool fast_fp16_available(const int cc) {
+#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
+#define FLASH_ATTN_AVAILABLE
+#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
+
+static bool fp16_available(const int cc) {
+    return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
+}
+
+static bool fast_fp16_available(const int cc) {
+    return fp16_available(cc) && cc != 610;
+}
+
+// To be used for feature selection of external libraries, e.g. cuBLAS.
+static bool fast_fp16_hardware_available(const int cc) {
     return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
 }
 
-static constexpr bool fp16_mma_available(const int cc) {
+// Any FP16 tensor core instructions are available for ggml code.
+static bool fp16_mma_available(const int cc) {
+    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
+}
+
+// To be used for feature selection of external libraries, e.g. cuBLAS.
+static bool fp16_mma_hardware_available(const int cc) {
     return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
 }
 
-static constexpr bool int8_mma_available(const int cc) {
-    return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING;
+// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
+static bool new_mma_available(const int cc) {
+    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
+}
+
+static bool cp_async_available(const int cc) {
+    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
+}
+
+static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+    return __AMDGCN_WAVEFRONT_SIZE;
+#else
+    return 32;
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 }
 
 [[noreturn]]
@@ -186,53 +271,46 @@ static __device__ void no_device_code(
 #define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
 #endif // __CUDA_ARCH__
 
+template
 static __device__ __forceinline__ int warp_reduce_sum(int x) {
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
     return __reduce_add_sync(0xffffffff, x);
 #else
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, offset, width);
     }
     return x;
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 }
 
+template
 static __device__ __forceinline__ float warp_reduce_sum(float x) {
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, offset, width);
     }
     return x;
 }
 
+template
 static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        a.x += __shfl_xor_sync(0xffffffff, a.x, offset, 32);
-        a.y += __shfl_xor_sync(0xffffffff, a.y, offset, 32);
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, offset, width);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, offset, width);
     }
     return a;
 }
 
+template
 static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #ifdef FP16_AVAILABLE
-
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        const half2 a_other = __shfl_xor_sync(0xffffffff, a, offset, 32);
-        reinterpret_cast(a.x) +=  __low2half(a_other);
-        reinterpret_cast(a.y) += __high2half(a_other);
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, width));
     }
     return a;
-#else
-#pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, 32));
-    }
-    return a;
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 
 #else
     NO_DEVICE_CODE;
@@ -240,10 +318,11 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #endif // FP16_AVAILABLE
 }
 
+template
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, width));
     }
     return x;
 }
@@ -265,35 +344,34 @@ static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b
 }
 
 static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-
-#if CUDART_VERSION >= CUDART_HMAX
+#if defined(GGML_USE_HIP) && HIP_VERSION >= 50700000
+    return half2(__hmax(a.x, b.x), __hmax(a.y, b.y));
+#elif !defined(GGML_USE_HIP) && CUDART_VERSION >= CUDART_HMAX
     return __hmax2(a, b);
-#else
+#elif !defined(GGML_USE_HIP)
     half2 ret;
     reinterpret_cast(ret.x) = __float2half(fmaxf( __low2float(a),  __low2float(b)));
     reinterpret_cast(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b)));
     return ret;
-#endif // CUDART_VERSION >= CUDART_HMAX
-
 #else
     GGML_UNUSED(a);
     GGML_UNUSED(b);
     NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#endif
 }
 
+template
 static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
 #pragma unroll
-   for (int offset = 16; offset > 0; offset >>= 1) {
-       x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
+   for (int offset = width/2; offset > 0; offset >>= 1) {
+       x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, width));
    }
    return x;
 #else
    GGML_UNUSED(x);
    NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
 }
 
 #if CUDART_VERSION < CUDART_HMASK
@@ -333,13 +411,13 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
 
 #else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
     return __dp4a(a, b, c);
-#else // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
+#else // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
     const int8_t * a8 = (const int8_t *) &a;
     const int8_t * b8 = (const int8_t *) &b;
     return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
-#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
 
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 }
@@ -516,6 +594,7 @@ struct ggml_cuda_device_info {
         bool    vmm;                // virtual memory support
         size_t  vmm_granularity;    // granularity of virtual memory
         size_t  total_vram;
+        int     warp_size;          // Number of threads in a dispatch
     };
 
     cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};
@@ -588,7 +667,7 @@ struct ggml_tensor_extra_gpu {
 };
 
 
-#if (CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)
+#if ((CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)) || defined(GGML_HIP_GRAPHS)
 #define USE_CUDA_GRAPH
 #endif
 
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/concat.cu b/ml/backend/ggml/ggml/src/ggml-cuda/concat.cu
index 5eb9f08d..aafbaf80 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/concat.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/concat.cu
@@ -124,7 +124,7 @@ static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE)
           uint64_t   nb1,
           uint64_t   nb2,
           uint64_t   nb3){
-    static_assert(dim >= 0 && dim <= 3, "dim must be between 0 and 3");
+    static_assert(dim >= 0 && dim <= 3, "dim must be in [0, 3]");
 
     const int64_t i3 = blockIdx.z;
     const int64_t i2 = blockIdx.y;
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/convert.cu b/ml/backend/ggml/ggml/src/ggml-cuda/convert.cu
index 5b0dface..795b720d 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/convert.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/convert.cu
@@ -599,7 +599,7 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
         case GGML_TYPE_Q5_1:
             return dequantize_block_cuda;
         case GGML_TYPE_Q8_0:
-            if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= GGML_CUDA_CC_PASCAL) {
+            if (fp16_available(ggml_cuda_info().devices[ggml_cuda_get_device()].cc)) {
                 return dequantize_block_q8_0_f16_cuda;
             }
             return dequantize_block_cuda;
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/cp-async.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/cp-async.cuh
new file mode 100644
index 00000000..ecb65999
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/cp-async.cuh
@@ -0,0 +1,46 @@
+// Simplified API for asynchronous data loading.
+
+#include "common.cuh"
+
+// Copies data from global to shared memory, cg == cache global.
+// Both the src and dst pointers must be aligned to 16 bit.
+// Shared memory uses 32 bit addressing, the pointer is passed as unsigned int.
+// Generic pointers can be converted to 32 bit shared memory pointers using __cvta_generic_to_shared.
+// Only the 16 bit copy is exposed because 4 and 8 bit copies did not yield performance improvements.
+template 
+static __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, const void * src) {
+    static_assert(preload == 0 || preload == 64 || preload == 128 || preload == 256, "bad preload");
+#ifdef CP_ASYNC_AVAILABLE
+#if CUDART_VERSION >= 11040
+    if (preload == 256) {
+        asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], 16;"
+            : : "r"(dst), "l"(src));
+    } else if (preload == 128) {
+        asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], 16;"
+            : : "r"(dst), "l"(src));
+    } else if (preload == 64) {
+        asm volatile("cp.async.cg.shared.global.L2::64B [%0], [%1], 16;"
+            : : "r"(dst), "l"(src));
+    } else
+#endif // CUDART_VERSION >= 11040
+    {
+        asm volatile("cp.async.cg.shared.global [%0], [%1], 16;"
+            : : "r"(dst), "l"(src));
+    }
+#else
+    GGML_UNUSED(dst);
+    GGML_UNUSED(src);
+    NO_DEVICE_CODE;
+#endif // CP_ASYNC_AVAILABLE
+}
+
+// Makes each thread wait until its asynchronous data copies are done.
+// This does NOT provide any additional synchronization.
+// In particular, when copying data with multiple warps a call to __syncthreads will be needed.
+static __device__ __forceinline__ void cp_async_wait_all() {
+#ifdef CP_ASYNC_AVAILABLE
+    asm volatile("cp.async.wait_all;");
+#else
+    NO_DEVICE_CODE;
+#endif // CP_ASYNC_AVAILABLE
+}
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu b/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu
index 54c0f66d..cca2bee0 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu
@@ -1,4 +1,5 @@
 #include "cpy.cuh"
+#include "dequantize.cuh"
 
 typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
 
@@ -82,13 +83,14 @@ static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
 }
 
 static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
-    const block_q8_0 * xi = (const block_q8_0 *) cxi;
-    float * dsti = (float *) cdsti;
+    float * cdstf = (float *)(cdsti);
 
-    const float d = (float)xi->d;
-
-    for (int j = 0; j < QK8_0; j++) {
-       dsti[j] = xi->qs[j] * d;
+#pragma unroll
+    for (int j = 0; j < QK8_0; j += 2) {
+        dfloat2 dq;
+        dequantize_q8_0(cxi, 0, j, dq);
+        *(cdstf + j) = dq.x;
+        *(cdstf + j + 1) = dq.y;
     }
 }
 
@@ -225,6 +227,18 @@ static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
     memcpy(dsti->qh, &qh, sizeof(qh));
 }
 
+template
+static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
+    float * cdstf = (float *)(cdsti);
+
+#pragma unroll
+    for (int j = 0; j < qk/2; j++) {
+        dfloat2 dq;
+        dequant(cxi, 0, j, dq);
+        *(cdstf + j) = dq.x;
+        *(cdstf + j + qk/2) = dq.y;
+    }
+}
 
 static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
     if (x <= val[0]) return 0;
@@ -387,6 +401,19 @@ static void ggml_cpy_f32_q4_0_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
+static void ggml_cpy_q4_0_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02,
+    const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12,
+    const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
+    const int num_blocks = ne;
+    cpy_q_f32, QK4_0><<>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+         ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f32_q4_1_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -398,6 +425,19 @@ static void ggml_cpy_f32_q4_1_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
+static void ggml_cpy_q4_1_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02,
+    const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12,
+    const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
+    const int num_blocks = ne;
+    cpy_q_f32, QK4_1><<>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+         ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f32_q5_0_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -409,6 +449,19 @@ static void ggml_cpy_f32_q5_0_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
+static void ggml_cpy_q5_0_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02,
+    const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12,
+    const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
+    const int num_blocks = ne;
+    cpy_q_f32, QK5_0><<>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f32_q5_1_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -420,6 +473,19 @@ static void ggml_cpy_f32_q5_1_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
+static void ggml_cpy_q5_1_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02,
+    const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12,
+    const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
+    const int num_blocks = ne;
+    cpy_q_f32, QK5_1><<>>(
+        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f32_iq4_nl_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -488,14 +554,25 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
         ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
         ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
         ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
+            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
         ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
         ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
@@ -524,14 +601,22 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
         return (void*) cpy_q_f32;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
         return (void*) cpy_f32_q;
+    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32, QK4_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
         return (void*) cpy_f32_q;
+    } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32, QK4_1>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
         return (void*) cpy_f32_q;
+    } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32, QK5_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
         return (void*) cpy_f32_q;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
         return (void*) cpy_f32_q;
+    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_q_f32, QK5_1>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
         return (void*) cpy_f32_f16;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/cross-entropy-loss.cu b/ml/backend/ggml/ggml/src/ggml-cuda/cross-entropy-loss.cu
index ed09406a..0ce4afbb 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/cross-entropy-loss.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/cross-entropy-loss.cu
@@ -5,95 +5,89 @@
 #include 
 #include 
 
-static __global__ void cross_entropy_loss_f32(const float * logits, const float * labels, float * dst, const int nclasses, const int k) {
-    const int warp_id = threadIdx.x / WARP_SIZE;
-    const int lane_id = threadIdx.x % WARP_SIZE;
-    const int i0 = blockDim.x*blockIdx.x + warp_id*WARP_SIZE;
+template 
+static __global__ void cross_entropy_loss_f32(
+        const float * __restrict__ logits, const float * __restrict__ labels, float * __restrict__ dst, const int nclasses, const int k) {
+    extern __shared__ float tmp[];
 
-    const int ne_tmp = WARP_SIZE*nclasses;
-
-    extern __shared__ float tmp_all[];
-    float * tmp_logits = tmp_all + (2*warp_id + 0)*ne_tmp;
-    float * tmp_labels = tmp_all + (2*warp_id + 1)*ne_tmp;
-
-    // Each warp first loads ne_tmp logits/labels into shared memory:
-    for (int i = lane_id; i < ne_tmp; i += WARP_SIZE) {
-        const int ig = i0*nclasses + i; // ig == i global
-
-        tmp_logits[i] = ig < k*nclasses ? logits[ig] : 0.0f;
-        tmp_labels[i] = ig < k*nclasses ? labels[ig] : 0.0f;
-    }
-
-    // Each thread in the warp then calculates the cross entropy loss for a single row.
-    // TODO: pad in order to avoid shared memory bank conflicts.
+    logits += int64_t(blockIdx.x)*nclasses;
+    labels += int64_t(blockIdx.x)*nclasses;
 
     // Find maximum for softmax:
-    float max = -INFINITY;
-    for (int i = 0; i < nclasses; ++i) {
-        max = fmaxf(max, tmp_logits[lane_id*nclasses + i]);
+    float max_logit = -INFINITY;
+    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
+        const float val = logits[i];
+        max_logit = fmaxf(max_logit, val);
+
+        if (use_shared) {
+            tmp[i] = val;
+        }
     }
+    max_logit = warp_reduce_max(max_logit);
 
     // Calculate log(softmax(logits)) which is just logits - max:
     float sum = 0.0f;
-    for (int i = 0; i < nclasses; ++i) {
-        float val = tmp_logits[lane_id*nclasses + i] - max;
-        sum += expf(val);
-        tmp_logits[lane_id*nclasses + i] = val;
+    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
+        const float logit_i = use_shared ? tmp[i] : logits[i];
+        sum += expf(logit_i - max_logit);
     }
+    sum = warp_reduce_sum(sum);
     sum = logf(sum);
 
     // log(exp(logits - max) / sum) = (logits - max) - log(sum)
     float loss = 0.0f;
-    for (int i = 0; i < nclasses; ++i) {
-        loss += (tmp_logits[lane_id*nclasses + i] - sum) * tmp_labels[lane_id*nclasses + i];
+    for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
+        const float logit_i = use_shared ? tmp[i] : logits[i];
+        loss += (logit_i - max_logit - sum) * labels[i];
     }
     loss = -warp_reduce_sum(loss) / (float)k;
 
-    __syncthreads();
-
-    if (lane_id == 0) {
-        tmp_all[warp_id] = loss;
-    }
-
-    __syncthreads();
-
-    if (warp_id != 0) {
-        return;
-    }
-
-    loss = lane_id < CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE/WARP_SIZE ? tmp_all[lane_id] : 0.0f;
-    loss = warp_reduce_sum(loss);
-
-    if (lane_id != 0) {
+    if (threadIdx.x != 0) {
         return;
     }
 
     dst[blockIdx.x] = loss;
 }
 
-static __global__ void cross_entropy_loss_back_f32(const float * logits, const float * labels, const float * loss, float * dst, const int nclasses) {
+template 
+static __global__ void cross_entropy_loss_back_f32(
+        const float * __restrict__ grad, const float * __restrict__ logits, const float * __restrict__ labels,
+        float * __restrict__ dst, const int nclasses) {
     extern __shared__ float tmp[];
 
+    logits += int64_t(blockIdx.x)*nclasses;
+    labels += int64_t(blockIdx.x)*nclasses;
+    dst    += int64_t(blockIdx.x)*nclasses;
+
     float maxval = -INFINITY;
     for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
-        const float val = logits[blockIdx.x*nclasses + i];
+        const float val = logits[i];
         maxval = fmaxf(maxval, val);
-        tmp[i] = val;
+
+        if (use_shared) {
+            tmp[i] = val;
+        }
     }
     maxval = warp_reduce_max(maxval);
 
     float sum = 0.0f;
     for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
-        const float val = expf(tmp[i] - maxval);
+        const float val = expf((use_shared ? tmp[i] : logits[i]) - maxval);
         sum += val;
-        tmp[i] = val;
+
+        if (use_shared) {
+            tmp[i] = val;
+        } else {
+            dst[i] = val;
+        }
     }
     sum = warp_reduce_sum(sum);
     const float sm_scale = 1.0f/sum;
 
-    const float d_by_nrows = *loss/gridDim.x;
+    const float d_by_nrows = *grad/gridDim.x;
     for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
-        dst[blockIdx.x*nclasses + i] = (tmp[i]*sm_scale - labels[blockIdx.x*nclasses + i])*d_by_nrows;
+        const float val = use_shared ? tmp[i] : dst[i];
+        dst[i] = (val*sm_scale - labels[i])*d_by_nrows;
     }
 }
 
@@ -119,48 +113,77 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t stream = ctx.stream();
 
-    const dim3 blocks_dim(CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1);
-    const dim3 blocks_num((nrows + CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE - 1) / CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE, 1, 1);
-    const int shmem = 2*CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE*ne00*sizeof(float);
+    const dim3 blocks_dim(WARP_SIZE, 1, 1);
+    const dim3 blocks_num(nrows, 1, 1);
+    const size_t nbytes_shared = ne00*sizeof(float);
+
+    const int    id    = ggml_cuda_get_device();
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
 
     ggml_cuda_pool_alloc dst_tmp(pool, blocks_num.x);
 
-    cross_entropy_loss_f32<<>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
+    if (nbytes_shared <= smpbo) {
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+        static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
+        if (!shared_memory_limit_raised[id]) {
+            CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_f32, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
+            shared_memory_limit_raised[id] = true;
+        }
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+        cross_entropy_loss_f32<<>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
+    } else {
+        cross_entropy_loss_f32<<>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
+    }
+    CUDA_CHECK(cudaGetLastError());
 
     // Combine results from individual blocks:
     sum_f32_cuda(pool, dst_tmp.ptr, dst_d, blocks_num.x, stream);
 }
 
 void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    const ggml_tensor * src1 = dst->src[1];
-    const ggml_tensor * opt0 = dst->src[2];
+    const ggml_tensor * grad  = dst->src[0];
+    const ggml_tensor * src0f = dst->src[1];
+    const ggml_tensor * src1f = dst->src[2];
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT(opt0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0f->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1f->type == GGML_TYPE_F32);
+    GGML_ASSERT( grad->type == GGML_TYPE_F32);
+    GGML_ASSERT(  dst->type == GGML_TYPE_F32);
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(src1));
-    GGML_ASSERT(ggml_is_contiguous(opt0));
+    GGML_ASSERT(ggml_is_scalar(grad));
+    GGML_ASSERT(ggml_is_contiguous(src0f));
+    GGML_ASSERT(ggml_is_contiguous(src1f));
     GGML_ASSERT(ggml_is_contiguous(dst));
-    GGML_ASSERT(ggml_are_same_shape(src0, src1));
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_are_same_shape(src0f, src1f));
+    GGML_ASSERT(ggml_are_same_shape(src0f, dst));
 
-    const int64_t ne00  = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+    const int64_t ne00  = src0f->ne[0];
+    const int64_t nrows = ggml_nrows(src0f);
 
-    const float * src0_d = (const float *) src0->data;
-    const float * src1_d = (const float *) src1->data;
-    const float * opt0_d = (const float *) opt0->data;
-    float       * dst_d  = (float       *) dst->data;
+    const float * grad_d  = (const float *) grad->data;
+    const float * src0f_d = (const float *) src0f->data;
+    const float * src1f_d = (const float *) src1f->data;
+    float       * dst_d   = (float       *) dst->data;
 
     cudaStream_t stream = ctx.stream();
 
     const dim3 blocks_dim(WARP_SIZE, 1, 1);
     const dim3 blocks_num(nrows, 1, 1);
-    const int shmem = ne00*sizeof(float);
+    const size_t nbytes_shared = ne00*sizeof(float);
 
-    cross_entropy_loss_back_f32<<>>(src0_d, src1_d, opt0_d, dst_d, ne00);
+    const int    id    = ggml_cuda_get_device();
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+
+    if (nbytes_shared <= smpbo) {
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+        static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
+        if (!shared_memory_limit_raised[id]) {
+            CUDA_CHECK(cudaFuncSetAttribute(cross_entropy_loss_back_f32, cudaFuncAttributeMaxDynamicSharedMemorySize, smpbo));
+            shared_memory_limit_raised[id] = true;
+        }
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
+        cross_entropy_loss_back_f32<<>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
+    } else {
+        cross_entropy_loss_back_f32<<>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
+    }
 }
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh
index ee9752da..7b9566fb 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh
@@ -516,6 +516,96 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         nullptr;
 }
 
+template // D == head size
+__launch_bounds__(D, 1)
+static __global__ void flash_attn_stream_k_fixup(
+        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
+    constexpr int ncols = ncols1*ncols2;
+
+    const int bidx0 = blockIdx.x;
+    const int j     = blockIdx.y;
+    const int c     = blockIdx.z;
+    const int jc    = j*ncols2 + c;
+    const int tid   = threadIdx.x;
+
+    const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
+
+    const int iter_k = ne11 / FATTN_KQ_STRIDE;
+    const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
+
+    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+
+    const bool did_not_have_any_data   = kbc0 == kbc0_stop;
+    const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
+    const bool did_not_write_last      = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0;
+    if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
+        return;
+    }
+
+    const int channel = kbc0 / (iter_k*iter_j);
+    const int jt      = (kbc0 - channel*iter_k*iter_j) / iter_k;
+
+    if (jt*ncols1 + j >= ne01) {
+        return;
+    }
+
+    dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
+
+    // Load the partial result that needs a fixup:
+    float dst_val = 0.0f;
+    float max_val = 0.0f;
+    float rowsum  = 0.0f;
+    {
+        dst_val = *dst;
+
+        const float2 tmp = dst_fixup[bidx0*ncols + jc];
+        max_val = tmp.x;
+        rowsum  = tmp.y;
+    }
+
+    // Iterate over previous blocks and compute the combined results.
+    // All CUDA blocks that get here must have a previous block that needs a fixup.
+    int bidx = bidx0 - 1;
+    int kbc_stop = kbc0;
+    while(true) {
+        const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+        if (kbc == kbc_stop) { // Did not have any data.
+            bidx--;
+            kbc_stop = kbc;
+            continue;
+        }
+
+        const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];
+
+        const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + jc];
+
+        // Scale the current and new value accumulators depending on the max. values.
+        const float max_val_new = fmaxf(max_val, tmp.x);
+
+        const float diff_val = max_val - max_val_new;
+        const float diff_add = tmp.x   - max_val_new;
+
+        const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
+        const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
+
+        dst_val = scale_val*dst_val + scale_add*dst_add;
+        rowsum  = scale_val*rowsum  + scale_add*tmp.y;
+
+        max_val = max_val_new;
+
+        // If this block started in a previous tile we are done and don't need to combine additional partial results.
+        if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
+            break;
+        }
+        bidx--;
+        kbc_stop = kbc;
+    }
+
+    // Write back final result:
+    *dst = dst_val / rowsum;
+}
+
 template // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
@@ -581,11 +671,14 @@ static void on_no_fattn_vec_case(const int D) {
     }
 }
 
-template 
+// parallel_blocks == 0 is stream-k decomposition
+template 
 void launch_fattn(
     ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
-    const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V
+    const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
 ) {
+    constexpr int ncols = ncols1 * ncols2;
+
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];
@@ -603,20 +696,25 @@ void launch_fattn(
 
     GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
 
+    GGML_ASSERT(Q->ne[3] == 1);
+
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
+    const int id  = ggml_cuda_get_device();
+    const int cc  = ggml_cuda_info().devices[id].cc;
+    const int nsm = ggml_cuda_info().devices[id].nsm;
 
     ggml_cuda_pool_alloc   K_f16(pool);
     ggml_cuda_pool_alloc   V_f16(pool);
     ggml_cuda_pool_alloc  dst_tmp(pool);
     ggml_cuda_pool_alloc dst_tmp_meta(pool);
 
-    char * K_data = (char *) K->data;
+    const char * K_data = (const char *) K->data;
     size_t nb11 = K->nb[1];
     size_t nb12 = K->nb[2];
     size_t nb13 = K->nb[3];
 
-    char * V_data = (char *) V->data;
+    const char * V_data = (const char *) V->data;
     size_t nb21 = V->nb[1];
     size_t nb22 = V->nb[2];
     size_t nb23 = V->nb[3];
@@ -649,39 +747,61 @@ void launch_fattn(
         nb23 = nb23*bs*sizeof(half)/ts;
     }
 
-    if (parallel_blocks > 1) {
-        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
-        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
-    }
+    const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
+    const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
 
     const dim3 block_dim(WARP_SIZE, nwarps, 1);
-    const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
-    const int  shmem = 0;
+    dim3 blocks_num;
+    if (parallel_blocks == 0) {
+        // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
+        const int max_blocks = 2*nsm;
+        const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
+        const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
+
+        const int nblocks_stream_k = max_blocks;
+
+        const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || tiles_efficiency_percent < 75;
+
+        blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
+        blocks_num.y = 1;
+        blocks_num.z = 1;
+
+        dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
+    } else {
+        blocks_num.x = parallel_blocks*ntiles_x;
+        blocks_num.y = Q->ne[2];
+        blocks_num.z = Q->ne[3];
+
+        if (parallel_blocks > 1) {
+            dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
+            dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
+        }
+    }
 
     float scale         = 1.0f;
     float max_bias      = 0.0f;
     float logit_softcap = 0.0f;
 
-    memcpy(&scale,         (float *) KQV->op_params + 0, sizeof(float));
-    memcpy(&max_bias,      (float *) KQV->op_params + 1, sizeof(float));
-    memcpy(&logit_softcap, (float *) KQV->op_params + 2, sizeof(float));
+    memcpy(&scale,         (const float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias,      (const float *) KQV->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
     if (logit_softcap != 0.0f) {
         scale /= logit_softcap;
     }
 
     const uint32_t n_head      = Q->ne[2];
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+    const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head))));
 
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    fattn_kernel<<>>(
+    fattn_kernel<<>>(
         (const char *) Q->data,
         K_data,
         V_data,
         mask ? ((const char *) mask->data) : nullptr,
-        (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
+        (parallel_blocks) > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
@@ -693,16 +813,22 @@ void launch_fattn(
     );
     CUDA_CHECK(cudaGetLastError());
 
-    if ((parallel_blocks) == 1) {
-        return;
+    if constexpr (parallel_blocks == 0) {
+        if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
+            const dim3 block_dim_combine(D, 1, 1);
+            const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
+
+            flash_attn_stream_k_fixup
+                <<>>
+                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
+        }
+    } else if constexpr (parallel_blocks > 1) {
+        const dim3 block_dim_combine(D, 1, 1);
+        const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
+
+        flash_attn_combine_results
+            <<>>
+            (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
     }
-
-    const dim3 block_dim_combine(D, 1, 1);
-    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
-    const int  shmem_combine = 0;
-
-    flash_attn_combine_results
-        <<>>
-        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
     CUDA_CHECK(cudaGetLastError());
 }
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh
new file mode 100644
index 00000000..718ee540
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -0,0 +1,1021 @@
+#include "common.cuh"
+#include "cp-async.cuh"
+#include "mma.cuh"
+#include "fattn-common.cuh"
+
+using namespace ggml_cuda_mma;
+
+typedef tile<16,  8, half2> tile_A;
+typedef tile< 8,  8, half2> tile_B;
+typedef tile<16,  8, half2> tile_B_16;
+typedef tile<16,  8, float> tile_C_KQ;
+typedef tile<16, 16, float> tile_C_KQ_16;
+typedef tile<16,  4, half2> tile_C_VKQ;
+typedef tile<16,  8, half2> tile_C_VKQ_16;
+
+template
+static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
+        const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV) {
+    constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
+
+    // If cp.async is available, load up to the highest power of 2 in D asynchronously:
+#ifdef CP_ASYNC_AVAILABLE
+    static_assert(D >= 64 && D < 512, "bad D");
+    constexpr int k0_sync_start = D/2 < 64 ? 32 : (D/2 < 128 ? 64 : 128);
+
+    const unsigned int tile_KV_32 = __cvta_generic_to_shared(tile_KV);
+
+    constexpr int preload = 64;
+    constexpr int h2_per_chunk = 16/sizeof(half2);
+    constexpr int chunks_per_row = k0_sync_start / h2_per_chunk;
+    constexpr int stride_i = WARP_SIZE / chunks_per_row;
+#pragma unroll
+    for (int i0 = 0; i0 < KQ_per_iter; i0 += nwarps*stride_i) {
+        const int i = i0 + threadIdx.y*stride_i + (chunks_per_row == WARP_SIZE ? 0 : threadIdx.x / chunks_per_row);
+        const int k = (chunks_per_row == WARP_SIZE ? threadIdx.x : threadIdx.x % chunks_per_row)*h2_per_chunk;
+
+        cp_async_cg_16(tile_KV_32 + (i*D2_padded + k)*sizeof(half2), KV + i*stride_KV + k);
+    }
+#else
+    constexpr int k0_sync_start = 0;
+#endif // CP_ASYNC_AVAILABLE
+    static_assert(k0_sync_start % WARP_SIZE == 0, "bad k0_sync_start");
+
+    // If D is not a power of 2, the rest is loaded synchronously.
+    // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
+    static_assert(KQ_per_iter % (4*nwarps) == 0, "out of bounds");
+#pragma unroll
+    for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
+        const int k0_start = stride_k == WARP_SIZE ? k0_sync_start : D/2 - (D/2) % (2*stride_k);
+        const int k0_stop  =                                         D/2 - (D/2) % (1*stride_k);
+        const int stride_i = WARP_SIZE / stride_k;
+
+        if (k0_start == k0_stop || k0_stop <= k0_sync_start) {
+            continue;
+        }
+
+#pragma unroll
+        for (int i0 = 0; i0 < KQ_per_iter; i0 += nwarps*stride_i) {
+            const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+
+#pragma unroll
+            for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+                const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+                tile_KV[i*D2_padded + k] = KV[i*stride_KV + k];
+            }
+        }
+    }
+}
+
+template
+static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
+        const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) {
+    static_assert(KQ_per_iter == 2*WARP_SIZE || KQ_per_iter == WARP_SIZE, "bad KQ_per_iter");
+#ifdef CP_ASYNC_AVAILABLE
+    constexpr int preload = KQ_per_iter * sizeof(half);
+    constexpr int cols_per_warp = 8*WARP_SIZE/KQ_per_iter;
+    constexpr int stride_j = nwarps * cols_per_warp;
+
+    const unsigned int tile_mask_32 = __cvta_generic_to_shared(tile_mask);
+
+#pragma unroll
+    for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
+        const int j = j0 + threadIdx.y*cols_per_warp +
+            (KQ_per_iter == 2*WARP_SIZE ? threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/8));
+
+        if (j0 + stride_j > ncols1 && j >= ncols1) {
+            break;
+        }
+
+        const int i = 4 * (KQ_per_iter == 2*WARP_SIZE ? threadIdx.x % (WARP_SIZE/4) : threadIdx.x % (WARP_SIZE/8));
+
+        cp_async_cg_16(tile_mask_32 + j*(KQ_per_iter*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i);
+    }
+#else
+    constexpr int cols_per_warp = 2*WARP_SIZE/KQ_per_iter;
+    constexpr int stride_j = nwarps * cols_per_warp;
+#pragma unroll
+    for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
+        const int j = j0 + threadIdx.y*cols_per_warp + (KQ_per_iter == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/2));
+
+        if (j0 + stride_j > ncols1 && j >= ncols1) {
+            break;
+        }
+
+        const int i = KQ_per_iter == 2*WARP_SIZE ? threadIdx.x : threadIdx.x % (WARP_SIZE/2);
+
+        tile_mask[j*(KQ_per_iter/2 + 4) + i] = mask_h2[j*stride_mask + i];
+    }
+#endif // CP_ASYNC_AVAILABLE
+}
+
+template
+static __device__ __forceinline__ void flash_attn_ext_f16_iter(
+        const float2 * const __restrict__ Q_f2,
+        const half2  * const __restrict__ K_h2,
+        const half2  * const __restrict__ V_h2,
+        const half2  * const __restrict__ mask_h2,
+        float2       * const __restrict__ dstk,
+        float2       * const __restrict__ dstk_fixup,
+        const float scale,
+        const float slope,
+        const float logit_softcap,
+        const int ne01,
+        const int ne02,
+        const int stride_KV,
+        const int stride_mask,
+        const int jt,
+        half2        * const __restrict__ tile_K,
+        half2        * const __restrict__ tile_V,
+        half2        * const __restrict__ tile_mask,
+        const tile_B * const __restrict__ Q_B,
+        tile_C_VKQ   * const __restrict__ VKQ_C,
+        float        * const __restrict__ KQ_max,
+        float        * const __restrict__ KQ_rowsum,
+        const int kb0) {
+#ifdef NEW_MMA_AVAILABLE
+    constexpr int cols_per_warp   = ntiles * tile_B::I;
+    constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
+    constexpr int np              = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+    constexpr int D2_padded       = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
+
+    const int k_VKQ_0 = kb0 * KQ_per_iter;
+    tile_C_KQ KQ_C[KQ_per_iter/(np*tile_C_KQ::I) * ntiles];
+
+    // Use wide variants of tiles if ntiles >= 2.
+    tile_B_16     * Q_B_16   = (tile_B_16     *) Q_B;
+    tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C;
+    tile_C_KQ_16  * KQ_C_16  = (tile_C_KQ_16  *) KQ_C;
+
+#ifdef CP_ASYNC_AVAILABLE
+    cp_async_wait_all();
+    __syncthreads();
+    flash_attn_ext_f16_load_tile(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV);
+#else
+    if (ncols2 > 1 || mask_h2) {
+        flash_attn_ext_f16_load_mask(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask);
+    }
+    flash_attn_ext_f16_load_tile(K_h2 + k_VKQ_0*stride_KV, tile_K, stride_KV);
+    __syncthreads();
+#endif // CP_ASYNC_AVAILABLE
+
+    // Calculate tile of KQ:
+#pragma unroll
+    for (int i_KQ_00 = 0; i_KQ_00 < KQ_per_iter; i_KQ_00 += np*tile_A::I) {
+        const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I;
+#pragma unroll
+        for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += tile_A::J) {
+            tile_A K_A;
+            load_ldmatrix(K_A, tile_K + i_KQ_0*D2_padded + k_KQ_0, D2_padded);
+            if (ntiles == 1) {
+                mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]);
+            } else {
+#pragma unroll
+                for (int t = 0; t < ntiles/2; ++t) {
+                    // Wide version of KQ_C is column-major => swap A and B.
+                    mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A);
+                }
+            }
+        }
+    }
+
+#ifndef CP_ASYNC_AVAILABLE
+    __syncthreads(); // Only needed if tile_K == tile_V.
+#endif // CP_ASYNC_AVAILABLE
+
+    if (use_logit_softcap) {
+        static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size");
+#pragma unroll
+        for (int i = 0; i < KQ_per_iter/(np*tile_C_KQ::I) * ntiles; ++i) {
+#pragma unroll
+            for (int l = 0; l < tile_C_KQ::ne; ++l) {
+                KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
+            }
+        }
+    }
+
+    float KQ_max_new[cols_per_thread];
+#pragma unroll
+    for (int col = 0; col < cols_per_thread; ++col) {
+        KQ_max_new[col] = KQ_max[col];
+    }
+    float KQ_rowsum_add[cols_per_thread] = {0.0f};
+
+    if (ntiles == 1) {
+        if (ncols2 > 1 || mask_h2) {
+#pragma unroll
+            for (int i00 = 0; i00 < KQ_per_iter; i00 += np*tile_C_KQ::I) {
+                const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I;
+#pragma unroll
+                for (int l = 0; l < tile_C_KQ::ne; ++l) {
+                    const int i = i0 + tile_C_KQ::get_i(l);
+                    const int j = ((threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l)) / ncols2;
+
+                    KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope *
+                        __half2float(((const half *) tile_mask)[j*(KQ_per_iter + 8) + i]);
+                }
+            }
+        }
+
+        // Calculate softmax for each KQ column using the current max. value.
+        // The divisor is stored in KQ_rowsum and will be applied at the end.
+        static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size");
+#pragma unroll
+        for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) {
+#pragma unroll
+            for (int l = 0; l < tile_C_KQ::ne; ++l) {
+                KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]);
+            }
+        }
+
+        // Values per KQ column are spread across 8 threads, does not need full warp reduce:
+#pragma unroll
+        for (int col = 0; col < cols_per_thread; ++col) {
+#pragma unroll
+            for (int offset = 16; offset >= 4; offset >>= 1) {
+                KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
+            }
+        }
+
+        static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size");
+
+#pragma unroll
+        for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) {
+#pragma unroll
+            for (int l = 0; l < tile_C_KQ::ne; ++l) {
+                KQ_C[k].x[l] = expf(KQ_C[k].x[l] - KQ_max_new[l % 2]);
+
+                KQ_rowsum_add[l % 2] += KQ_C[k].x[l];
+            }
+        }
+    } else { // ntiles > 1
+        if (ncols2 > 1 || mask_h2) {
+#pragma unroll
+            for (int i00 = 0; i00 < KQ_per_iter; i00 += np*tile_C_KQ_16::J) {
+                const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ_16::J;
+#pragma unroll
+                for (int t = 0; t < ntiles/2; ++t) {
+#pragma unroll
+                    for (int l0 = 0; l0 < tile_C_KQ_16::ne; l0 += 2) {
+                        const int i = (i0 + tile_C_KQ_16::get_j(l0)) / 2;
+                        const int j = ((threadIdx.y / np)*cols_per_warp + t*tile_C_KQ_16::I + tile_C_KQ_16::get_i(l0)) / ncols2;
+
+                        const float2 tmp = __half22float2(tile_mask[j*(KQ_per_iter/2 + 4) + i]);
+                        const int KQ_index = i00/(np*tile_C_KQ_16::J) * ntiles/2 + t;
+                        KQ_C_16[KQ_index].x[l0 + 0] += slope*tmp.x;
+                        KQ_C_16[KQ_index].x[l0 + 1] += slope*tmp.y;
+                    }
+                }
+            }
+        }
+
+        // Calculate softmax for each KQ column using the current max. value.
+        // The divisor is stored in KQ_rowsum and will be applied at the end.
+        static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size");
+#pragma unroll
+        for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ_16::J); ++k) {
+#pragma unroll
+            for (int t = 0; t < ntiles/2; ++t) {
+#pragma unroll
+                for (int l = 0; l < tile_C_KQ_16::ne; ++l) {
+                    const int KQ_index = 2*t + (l/2) % 2;
+                    KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l]);
+                }
+            }
+        }
+
+        // Values per KQ column are spread across 4 threads, does not need full warp reduce:
+#pragma unroll
+        for (int col = 0; col < cols_per_thread; ++col) {
+#pragma unroll
+            for (int offset = 2; offset >= 1; offset >>= 1) {
+                KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
+            }
+        }
+
+        static_assert(KQ_per_iter % (np*tile_C_KQ_16::J) == 0, "bad loop size");
+#pragma unroll
+        for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ_16::J); ++k) {
+#pragma unroll
+            for (int t = 0; t < ntiles/2; ++t) {
+#pragma unroll
+                for (int l = 0; l < tile_C_KQ_16::ne; ++l) {
+                    const int KQ_index = 2*t + (l/2) % 2;
+
+                    KQ_C_16[k*ntiles/2 + t].x[l] = expf(KQ_C_16[k*ntiles/2 + t].x[l] - KQ_max_new[KQ_index]);
+
+                    KQ_rowsum_add[KQ_index] += KQ_C_16[k*ntiles/2 + t].x[l];
+                }
+            }
+        }
+    }
+
+    {
+        float KQ_max_scale[cols_per_thread];
+#pragma unroll
+        for (int col = 0; col < cols_per_thread; ++col) {
+            KQ_max_scale[col] = expf(KQ_max[col] - KQ_max_new[col]);
+            KQ_max[col] = KQ_max_new[col];
+
+            // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
+            KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col];
+        }
+
+        if (ntiles == 1) {
+            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
+#pragma unroll
+            for (int i = 0; i < D/tile_C_VKQ::I; ++i) {
+#pragma unroll
+                for (int l = 0; l < tile_C_VKQ::ne; ++l) {
+                    VKQ_C[i].x[l] *= KQ_max_scale_h2;
+                }
+            }
+        } else {
+#pragma unroll
+            for (int col = 0; col < cols_per_thread; ++col) {
+                const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
+#pragma unroll
+                for (int i = 0; i < D/tile_C_VKQ_16::J; ++i) {
+#pragma unroll
+                    for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
+                        VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
+                    }
+                }
+            }
+        }
+    }
+
+    // Convert KQ C tiles into B tiles for VKQ calculation:
+    tile_B B[KQ_per_iter/(np*2*tile_B::J) * ntiles];
+    tile_B_16 * B_16 = (tile_B_16 *) B;
+    static_assert(KQ_per_iter % (np*2*tile_B::J) == 0, "bad loop size");
+    if (ntiles == 1) {
+#pragma unroll
+        for (int k = 0; k < KQ_per_iter/(np*2*tile_B::J); ++k) {
+            B[k] = get_transposed(get_half2(KQ_C[k]));
+        }
+    } else {
+        for (int k = 0; k < KQ_per_iter/(np*2*tile_B_16::J); ++k) {
+#pragma unroll
+            for (int t = 0; t < ntiles/2; ++t) {
+                B_16[k*ntiles/2 + t] = get_half2(KQ_C_16[k*ntiles/2 + t]);
+            }
+        }
+    }
+
+#ifdef CP_ASYNC_AVAILABLE
+    // Preload K tile for next iteration:
+    cp_async_wait_all();
+    __syncthreads();
+    if (!last_iter) {
+        if (ncols2 > 1 || mask_h2) {
+            flash_attn_ext_f16_load_mask(mask_h2 + (k_VKQ_0 + KQ_per_iter)/2, tile_mask, stride_mask);
+        }
+        flash_attn_ext_f16_load_tile(K_h2 + (k_VKQ_0 + KQ_per_iter)*stride_KV, tile_K, stride_KV);
+    }
+#else
+    flash_attn_ext_f16_load_tile(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV);
+    __syncthreads();
+#endif // CP_ASYNC_AVAILABLE
+
+    // Calculate VKQ tile:
+#pragma unroll
+    for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += tile_C_VKQ::I) {
+        static_assert((KQ_per_iter/2) % (np*tile_A::J) == 0, "bad loop size");
+#pragma unroll
+        for (int k00 = 0; k00 < KQ_per_iter/2; k00 += np*tile_A::J) {
+            const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
+
+            tile_A A;
+            load_ldmatrix_trans(A, tile_V + 2*k0*D2_padded + i_VKQ_0/2, D2_padded);
+            if (ntiles == 1) {
+                mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
+            } else {
+#pragma unroll
+                for (int t = 0; t < ntiles/2; ++t) {
+                    // Wide version of VKQ_C is column-major => swap A and B.
+                    mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A);
+                }
+            }
+        }
+    }
+
+#ifndef CP_ASYNC_AVAILABLE
+    __syncthreads(); // Only needed if tile_K == tile_V.
+#endif // CP_ASYNC_AVAILABLE
+
+#else
+    NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
+}
+
+template
+static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
+        const float2 * const __restrict__ Q_f2,
+        const half2  * const __restrict__ K_h2,
+        const half2  * const __restrict__ V_h2,
+        const half2  * const __restrict__ mask_h2,
+        float2       * const __restrict__ dstk,
+        float2       * const __restrict__ dstk_fixup,
+        const float scale,
+        const float slope,
+        const float logit_softcap,
+        const int ne01,
+        const int ne02,
+        const int stride_Q1,
+        const int stride_Q2,
+        const int stride_KV,
+        const int stride_mask,
+        const int jt,
+        const int kb0_start,
+        const int kb0_stop) {
+#ifdef NEW_MMA_AVAILABLE
+    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+    constexpr int ncols           = ncols1 * ncols2;
+    constexpr int cols_per_warp   = ntiles * tile_B::I;
+    constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
+    constexpr int np              = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+
+    static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");
+
+    static_assert(D           % nwarps == 0, "bad D");
+    static_assert(KQ_per_iter % nwarps == 0, "bad KQ_per_iter");
+
+    constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
+
+    // Temporary shared buffer for loading K/V data with KQ_per_iter*D logical elements:
+    extern __shared__ half2 tile_K[];
+#ifdef CP_ASYNC_AVAILABLE
+    half2 * tile_V    = tile_K + KQ_per_iter*D2_padded;
+#else
+    half2 * tile_V    = tile_K;
+#endif // CP_ASYNC_AVAILABLE
+    half2 * tile_mask = tile_V + KQ_per_iter*D2_padded;
+
+    tile_B       Q_B[D/(2*tile_B::J) * ntiles];
+    tile_C_VKQ VKQ_C[D/tile_C_VKQ::I * ntiles];
+
+    tile_B_16     * Q_B_16   = (tile_B_16     *) Q_B;
+    tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C;
+
+    float KQ_rowsum[cols_per_thread] = {0.0f};
+    float KQ_max[cols_per_thread];
+#pragma unroll
+    for (int col = 0; col < cols_per_thread; ++col) {
+        KQ_max[col] = -FLT_MAX/2.0f;
+    }
+
+    // Temporarily load Q data into tile_K, will be loaded into registers afterwards.
+    // The loading is done with decreasing granularity for D for better memory bandwidth.
+    const half2 scale_h2 = make_half2(scale, scale);
+#pragma unroll
+    for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
+        const int k0_start  = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
+        const int k0_stop   =                             D/2 - (D/2) % (1*stride_k);
+        const int stride_jc = WARP_SIZE / stride_k;
+
+        if (k0_start == k0_stop) {
+            continue;
+        }
+
+#pragma unroll
+        for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) {
+            const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+
+            if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) {
+                break;
+            }
+
+            const int j = jc / ncols2;
+            const int c = jc % ncols2;
+
+            if (jt*ncols1 + j < ne01) {
+#pragma unroll
+                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+                    const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k];
+                    tile_K[jc*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y);
+                }
+            } else {
+#pragma unroll
+                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+                    tile_K[jc*D2_padded + k] = make_half2(0.0f, 0.0f);
+                }
+            }
+        }
+    }
+
+    __syncthreads();
+
+    {
+        const int j0 = (threadIdx.y / np) * cols_per_warp;
+
+#pragma unroll
+        for (int k0 = 0; k0 < D/2; k0 += tile_B::J) {
+            if (ntiles == 1) {
+                load_ldmatrix(Q_B[k0/tile_B::J], tile_K + j0*D2_padded + k0, D2_padded);
+            } else {
+#pragma unroll
+                for (int t = 0; t < ntiles/2; ++t) {
+                    load_ldmatrix(Q_B_16[k0/tile_B_16::J * ntiles/2 + t],
+                        tile_K + (j0 + t*tile_B_16::I)*D2_padded + k0, D2_padded);
+                }
+            }
+        }
+    }
+
+    __syncthreads();
+
+    // Preload mask and K data for first iteration when using cp_async:
+#ifdef CP_ASYNC_AVAILABLE
+    if (ncols2 > 1 || mask_h2) {
+        flash_attn_ext_f16_load_mask(mask_h2 + kb0_start*KQ_per_iter/2, tile_mask, stride_mask);
+    }
+    flash_attn_ext_f16_load_tile(K_h2 + kb0_start*KQ_per_iter*stride_KV, tile_K, stride_KV);
+#endif // CP_ASYNC_AVAILABLE
+
+    // Iterate over ne11 == previous tokens:
+    for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
+        constexpr bool last_iter = false;
+        flash_attn_ext_f16_iter
+            (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
+             ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
+    }
+    { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
+        constexpr bool last_iter = true;
+        flash_attn_ext_f16_iter
+            (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
+             ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
+    }
+
+    // With cp_async there is no __syncthreads at the end of the iter,
+    //     there can be a race condition on shared memory access for combining/writing back results.
+#ifdef CP_ASYNC_AVAILABLE
+    if (nwarps*cols_per_warp > KQ_per_iter) {
+        __syncthreads();
+    }
+#endif // CP_ASYNC_AVAILABLE
+
+    // Finally, sum up partial KQ rowsums.
+    // The partial sums are spread across 8/4 threads each, does not need full reduce.
+    {
+        constexpr int offset_first = ntiles == 1 ? 16 : 2;
+        constexpr int offset_last  = ntiles == 1 ?  4 : 1;
+#pragma unroll
+        for (int col = 0; col < cols_per_thread; ++col) {
+#pragma unroll
+            for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
+                KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE);
+            }
+        }
+    }
+
+    // Write VKQ accumulators to shared memory in column-major format.
+    // It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
+    // Also for np > 1 the combination is done via these values in shared memory.
+    if (ntiles == 1) {
+        const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data
+#pragma unroll
+        for (int k0 = 0; k0 < D/2; k0 += tile_B::J) {
+            const tile_B B = get_transposed(VKQ_C[k0/tile_B::J]); // Conversion of C to B matrix puts it in column-major format.
+
+#pragma unroll
+            for (int l = 0; l < tile_B::ne; ++l) {
+                const int k = k0 + tile_B::get_j(l);
+
+                tile_K[jc_cwd*D2_padded + k] = B.x[l];
+            }
+        }
+    } else {
+#pragma unroll
+        for (int t = 0; t < ntiles/2; ++t) {
+            const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I;
+#pragma unroll
+            for (int k0 = 0; k0 < D/2; k0 += tile_C_VKQ_16::J) {
+#pragma unroll
+                for (int l = 0; l < tile_C_VKQ_16::ne; ++l) {
+                    const int j = j0 + tile_C_VKQ_16::get_i(l);
+                    const int k = k0 + tile_C_VKQ_16::get_j(l);
+
+                    tile_K[j*D2_padded + k] = VKQ_C_16[k0/tile_C_VKQ_16::J * ntiles/2 + t].x[l];
+                }
+            }
+        }
+    }
+
+    if constexpr (ntiles == 1) {
+        const int jc_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // jc combine write meta offset
+        const int jc_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta
+        const float2 KQ_cmr = make_float2(KQ_max[jc_cwmo], KQ_rowsum[jc_cwmo]); // KQ combine max rowsum
+
+        if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) {
+            // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
+            ((float2 *) tile_K)[jc_cwm*(D2_padded/2) + D/4] = KQ_cmr;
+        }
+
+        __syncthreads();
+
+        if (np == 1) {
+            // No combination is needed, the meta data can be directly written from registers to VRAM.
+            if (needs_fixup && threadIdx.x < tile_B::I) {
+                float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
+                dstk_fixup_meta[jc_cwm] = KQ_cmr;
+            }
+            if (is_fixup && threadIdx.x < tile_B::I) {
+                float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
+                dstk_fixup_meta[jc_cwm] = KQ_cmr;
+            }
+        }
+    } else {
+        static_assert(ntiles == 2 || ntiles == 4, "bad ntiles");
+        const int jc_cwm = threadIdx.y*cols_per_warp // jc combine write meta
+            + (ntiles == 4 ? ((threadIdx.x % 4) / 2) * tile_C_VKQ_16::I : 0)
+            + tile_C_VKQ_16::get_i(threadIdx.x % 4);
+        const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); // KQ combine max rowsum
+
+        if (((!needs_fixup && !is_fixup) || np > 1) && (ntiles == 4 || threadIdx.x % 4 < cols_per_thread)) {
+            // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
+            ((float2 *) tile_K)[jc_cwm*(D2_padded/2) + D/4] = KQ_cmr;
+        }
+
+        __syncthreads();
+
+        if (np == 1) {
+            // No combination is needed, the meta data can be directly written from registers to VRAM.
+            if (needs_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) {
+                float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
+                dstk_fixup_meta[jc_cwm] = KQ_cmr;
+            }
+            if (is_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) {
+                float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
+                dstk_fixup_meta[jc_cwm] = KQ_cmr;
+            }
+        }
+    }
+
+    static_assert(np == 1 || ntiles == 1 || ntiles == 2, "bad ntiles");
+    if (np > 1 && threadIdx.y % np == 0) {
+        // Combine the meta data for parallel warps via shared memory.
+        // Warps with threadIdx.y % np != 0 must NOT return early.
+        // All threads must return simultaneously to avoid race conditions with work on the next tile.
+
+        constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1;
+
+        const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
+        float2 * const meta_ptr = ((float2 *) tile_K) + jc_meta*(D2_padded/2) + D/4;
+        float2 meta[nmeta];
+#pragma unroll
+        for (int imeta = 0; imeta < nmeta; ++imeta) {
+            meta[imeta] = meta_ptr[imeta * WARP_SIZE * D2_padded/2];
+        }
+
+        float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps.
+#pragma unroll
+        for (int imeta = 1; imeta < nmeta; ++imeta) {
+            KQ_cmn = fmaxf(KQ_cmn, meta[imeta].x);
+        }
+#pragma unroll
+        for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
+            if (offset >= WARP_SIZE) {
+                continue;
+            }
+            KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
+        }
+
+        float KQ_cms[nmeta]; // KQ combine max scale per warp.
+#pragma unroll
+        for (int imeta = 0; imeta < nmeta; ++imeta) {
+            KQ_cms[imeta] = expf(meta[imeta].x - KQ_cmn);
+        }
+
+        float KQ_crs = KQ_cms[0]*meta[0].y; // KQ combine rowsum, scaled sum of all parallel warps.
+#pragma unroll
+        for (int imeta = 1; imeta < nmeta; ++imeta) {
+            KQ_crs += KQ_cms[imeta]*meta[imeta].y;
+        }
+#pragma unroll
+        for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
+            if (offset >= WARP_SIZE) {
+                continue;
+            }
+            KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
+        }
+
+        // Write back combined meta data:
+#pragma unroll
+        for (int imeta = 0; imeta < nmeta; ++imeta) {
+            if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) {
+                // Combined KQ max scale + rowsum.
+                meta_ptr[imeta * WARP_SIZE * D2_padded/2] = make_float2(KQ_cms[imeta], KQ_crs);
+            }
+        }
+
+        // Combined KQ max + rowsum.
+        static_assert(cols_per_warp <= WARP_SIZE);
+        if (needs_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
+            float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
+            dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
+        }
+        if (is_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
+            float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
+            dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
+        }
+    }
+
+    if (np > 1) {
+        __syncthreads();
+    }
+
+    if (np == 1 || threadIdx.y % np == 0) {
+        // The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums.
+        // The values after that are for the partial results of the individual blocks.
+        float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(D/2));
+
+#pragma unroll
+        for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
+            const int k0_start  = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
+            const int k0_stop   =                             D/2 - (D/2) % (1*stride_k);
+            const int stride_jc = WARP_SIZE / stride_k;
+
+            if (k0_start == k0_stop) {
+                continue;
+            }
+
+#pragma unroll
+            for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) {
+                const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+
+                if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) {
+                    break;
+                }
+
+                const int jc_tile_K = (jc_dst/cols_per_warp)*(np*cols_per_warp) + jc_dst % cols_per_warp;
+
+                const int j_dst = jc_dst / ncols2;
+                const int c_dst = jc_dst % ncols2;
+
+                if (!is_fixup && jt*ncols1 + j_dst >= ne01) {
+                    continue;
+                }
+
+                const float * meta_j = (const float *) tile_K + jc_tile_K*D2_padded + D/2;
+#pragma unroll
+                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+                    float2 dstk_val = make_float2(0.0f, 0.0f);
+#pragma unroll
+                    for (int ip = 0; ip < np; ++ip) {
+                        const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*cols_per_warp * D2_padded + 0];
+                        const float2 dstk_val_add = __half22float2(tile_K[(jc_tile_K + ip*cols_per_warp) * D2_padded + k]);
+                        dstk_val.x += dstk_val_add.x*KQ_crs;
+                        dstk_val.y += dstk_val_add.y*KQ_crs;
+                    }
+
+                    if (!needs_fixup && !is_fixup) {
+                        const float KQ_rowsum_j = meta_j[1];
+                        dstk_val.x /= KQ_rowsum_j;
+                        dstk_val.y /= KQ_rowsum_j;
+                    }
+
+                    if (is_fixup) {
+                        dstk_fixup_data[jc_dst*(D/2) + k] = dstk_val;
+                    } else {
+                        dstk[((jt*ncols1 + j_dst)*ne02 + c_dst)*(D/2) + k] = dstk_val;
+                    }
+                }
+            }
+        }
+    }
+
+    if (np > 1) {
+        __syncthreads();
+    }
+#else
+    NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
+}
+
+template
+__launch_bounds__(nwarps*WARP_SIZE, 2)
+static __global__ void flash_attn_ext_f16(
+        const char * __restrict__ Q,
+        const char * __restrict__ K,
+        const char * __restrict__ V,
+        const char * __restrict__ mask,
+        float      * __restrict__ dst,
+        float2     * __restrict__ dst_meta,
+        const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
+        const float logit_softcap,
+        const int ne00,
+        const int ne01,
+        const int ne02,
+        const int ne03,
+        const int ne10,
+        const int ne11,
+        const int ne12,
+        const int ne13,
+        const int ne31,
+        const int nb31,
+        const int nb01,
+        const int nb02,
+        const int nb03,
+        const int nb11,
+        const int nb12,
+        const int nb13,
+        const int nb21,
+        const int nb22,
+        const int nb23,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const int ne3) {
+#if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
+
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
+    static_assert(FATTN_KQ_STRIDE % KQ_per_iter == 0, "bad KQ_per_iter");
+
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+
+    const int stride_Q1   = nb01 / sizeof(float2);
+    const int stride_Q2   = nb02 / sizeof(float2);
+    const int stride_KV   = nb11 / sizeof(half2);
+    const int stride_mask = nb31 / sizeof(half2);
+
+    const int iter_k = ne11 / FATTN_KQ_STRIDE;
+    const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
+
+    constexpr int kb_niter = FATTN_KQ_STRIDE / KQ_per_iter; // Number of kernel iterations per assigned KQ slice.
+
+    // kbc == k block continuous, current index in continuous ijk space.
+    int       kbc      = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+    const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+
+    // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
+    // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
+    // In the most general case >2 seams can fall into the same tile.
+
+    // kb0 == k start index when in the output tile.
+    int kb0_start = kbc % iter_k;
+    int kb0_stop  = min(iter_k, kb0_start + kbc_stop - kbc);
+    while (kbc < kbc_stop && kb0_stop == iter_k) {
+        const int channel = kbc / (iter_k*iter_j);
+        const int jt      = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
+
+        const float2 * Q_f2    = (const float2 *) (Q + nb02* channel*ncols2);
+        const half2  * K_h2    = (const half2  *) (K + nb12*(channel*ncols2 / gqa_ratio));
+        const half2  * V_h2    = (const half2  *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape
+        const half2  * mask_h2 = ncols2 > 1 || mask ? (const half2  *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
+        float2       * dstk    = ((float2 *) dst) + channel*(ncols2 * D/2);
+
+        const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
+
+        const int kb0_start_kernel = kb0_start * kb_niter;
+        const int kb0_stop_kernel  = kb0_stop  * kb_niter;
+
+        constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
+        if (kb0_start == 0) {
+            constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
+            flash_attn_ext_f16_process_tile
+                (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+                 ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
+        } else {
+            constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
+            flash_attn_ext_f16_process_tile
+                (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+                 ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
+        }
+
+        kbc += iter_k;
+        kbc -= kbc % iter_k;
+
+        kb0_start = 0;
+        kb0_stop  = min(iter_k, kbc_stop - kbc);
+    }
+
+    if (kbc >= kbc_stop) {
+        return;
+    }
+
+    const int channel = kbc / (iter_k*iter_j);
+    const int jt      = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
+
+    const float2 * Q_f2    = (const float2 *) (Q + nb02* channel*ncols2);
+    const half2  * K_h2    = (const half2  *) (K + nb12*(channel*ncols2 / gqa_ratio));
+    const half2  * V_h2    = (const half2  *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape
+    const half2  * mask_h2 = ncols2 > 1 || mask ? (const half2  *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
+    float2       * dstk    = ((float2 *) dst) + channel*(ncols2 * D/2);
+
+    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
+
+    const int kb0_start_kernel = kb0_start * kb_niter;
+    const int kb0_stop_kernel  = kb0_stop  * kb_niter;
+
+    constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
+    constexpr bool needs_fixup = false;
+    flash_attn_ext_f16_process_tile
+        (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
+         ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
+#else
+    NO_DEVICE_CODE;
+#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
+}
+
+template 
+void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    constexpr int ncols         = ncols1 * ncols2;
+    constexpr int KQ_per_iter   = D <= 128 && ncols1 <= 64 ? 64 : 32;
+    constexpr int nwarps        = (KQ_per_iter == 32 && ncols <= 16) ? 2 : 4;
+    constexpr int ntiles        = ncols <= 8 ? 1 : (ncols <= 64 ? 2 : 4);
+    constexpr int cols_per_warp = ntiles * tile_B::I;
+
+    static_assert(D     %    tile_B::J  == 0, "bad D");
+    static_assert(ncols % cols_per_warp == 0, "bad ncols");
+
+    const ggml_tensor * KQV = dst;
+    const int id    = ggml_cuda_get_device();
+    const int cc    = ggml_cuda_info().devices[id].cc;
+
+    const int KQ_shared_rows = cp_async_available(cc) ? 2*KQ_per_iter : KQ_per_iter;
+
+    const size_t nbytes_shared_KV      = KQ_shared_rows       * (D           + 8) * sizeof(half);
+    const size_t nbytes_shared_mask    = ncols1               * (KQ_per_iter + 8) * sizeof(half);
+    const size_t nbytes_shared_combine = nwarps*cols_per_warp * (D           + 8) * sizeof(half);
+
+    const size_t nbytes_shared_total = std::max(nbytes_shared_KV + nbytes_shared_mask, nbytes_shared_combine);
+
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
+    fattn_kernel_t fattn_kernel;
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        fattn_kernel = flash_attn_ext_f16;
+    } else {
+        constexpr bool use_logit_softcap = true;
+        fattn_kernel = flash_attn_ext_f16;
+    }
+
+    launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, true, true);
+}
+
+
+#define DECL_FATTN_MMA_F16_CASE(D, ncols1, ncols2)                          \
+    template void ggml_cuda_flash_attn_ext_mma_f16_case                     \
+    (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
+
+#define DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(D, ncols) \
+    extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/1, 1); \
+    extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/2, 2); \
+    extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/4, 4); \
+    extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/8, 8); \
+
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,   8);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,   8);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,   8);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112,   8);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128,   8);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256,   8);
+
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  16);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  16);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  16);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112,  16);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128,  16);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256,  16);
+
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  32);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  32);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  32);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112,  32);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128,  32);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256,  32);
+
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  64);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  64);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  64);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112,  64);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128,  64);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256,  64);
+
+// Kernels with ncols == 128 are only 4% faster due to register pressure.
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128);
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128);
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128);
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128);
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128);
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128); // Needs too much shared memory.
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f16.cu
index 4d314dac..ef3569fa 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f16.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f16.cu
@@ -44,8 +44,13 @@ static __global__ void flash_attn_tile_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#ifdef FP16_AVAILABLE
+#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
+
     // Skip unused kernel variants for faster compilation:
+#ifdef FP16_MMA_AVAILABLE
+    NO_DEVICE_CODE;
+    return;
+#endif // FP16_MMA_AVAILABLE
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
         return;
@@ -280,7 +285,7 @@ static __global__ void flash_attn_tile_ext_f16(
     }
 #else
    NO_DEVICE_CODE;
-#endif // FP16_AVAILABLE
+#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }
 
 template 
@@ -288,16 +293,18 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
         case  64: {
-            constexpr int      D = 64;
-            constexpr int nwarps = 8;
+            constexpr int    D             = 64;
+            constexpr int    nwarps        = 8;
+            constexpr size_t nbytes_shared = 0;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16;
-            launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+            launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
         } break;
         case 128: {
-            constexpr int      D = 128;
-            constexpr int nwarps = 8;
+            constexpr int    D             = 128;
+            constexpr int    nwarps        = 8;
+            constexpr size_t nbytes_shared = 0;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16;
-            launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+            launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
         } break;
         default: {
             GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f32.cu
index bb336044..04b69c83 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f32.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f32.cu
@@ -44,11 +44,13 @@ static __global__ void flash_attn_tile_ext_f32(
         const int ne1,
         const int ne2,
         const int ne3) {
-#ifndef FLASH_ATTN_AVAILABLE
+#ifdef FLASH_ATTN_AVAILABLE
+
+    // Skip unused kernel variants for faster compilation:
+#ifdef FP16_MMA_AVAILABLE
     NO_DEVICE_CODE;
     return;
-#endif // FLASH_ATTN_AVAILABLE
-    // Skip unused kernel variants for faster compilation:
+#endif // FP16_MMA_AVAILABLE
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
         return;
@@ -280,6 +282,9 @@ static __global__ void flash_attn_tile_ext_f32(
             dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
         }
     }
+#else
+    NO_DEVICE_CODE;
+#endif // FLASH_ATTN_AVAILABLE
 }
 
 template 
@@ -287,16 +292,18 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
         case  64: {
-            constexpr int      D = 64;
-            constexpr int nwarps = 8;
+            constexpr int    D             = 64;
+            constexpr int    nwarps        = 8;
+            constexpr size_t nbytes_shared = 0;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32;
-            launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+            launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
         } break;
         case 128: {
-            constexpr int      D = 128;
-            constexpr int nwarps = 8;
+            constexpr int    D             = 128;
+            constexpr int    nwarps        = 8;
+            constexpr size_t nbytes_shared = 0;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32;
-            launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
+            launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
         } break;
         default: {
             GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec-f16.cuh
index 34a2992c..b7686c1e 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec-f16.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec-f16.cuh
@@ -41,7 +41,8 @@ static __global__ void flash_attn_vec_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#ifdef FP16_AVAILABLE
+#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
+
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
@@ -294,7 +295,7 @@ static __global__ void flash_attn_vec_ext_f16(
     }
 #else
    NO_DEVICE_CODE;
-#endif // FP16_AVAILABLE
+#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }
 
 template 
@@ -303,7 +304,8 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx,
     fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16;
     constexpr bool need_f16_K = D != 128;
     constexpr bool need_f16_V = D != 128 && D != 64;
-    launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
+    constexpr size_t nbytes_shared = 0;
+    launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
 }
 
 template 
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec-f32.cuh
index a28fc8b7..c1d2dd8d 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec-f32.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec-f32.cuh
@@ -41,6 +41,8 @@ static __global__ void flash_attn_vec_ext_f32(
         const int ne1,
         const int ne2,
         const int ne3) {
+#ifdef FLASH_ATTN_AVAILABLE
+
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
@@ -276,6 +278,9 @@ static __global__ void flash_attn_vec_ext_f32(
     if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
         dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
     }
+#else
+    NO_DEVICE_CODE;
+#endif // FLASH_ATTN_AVAILABLE
 }
 
 template 
@@ -284,7 +289,8 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,
     fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32;
     constexpr bool need_f16_K = D != 128;
     constexpr bool need_f16_V = D != 128 && D != 64;
-    launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
+    constexpr size_t nbytes_shared = 0;
+    launch_fattn(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
 }
 
 template 
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cu
new file mode 100644
index 00000000..8828652f
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cu
@@ -0,0 +1,648 @@
+// Old and deprecated WMMA FlashAttention implementation.
+// It is still needed for Volta since the memory layout of NVIDIA tensor cores changed with Turing.
+// Long-term the WMMA code should be replaced with a dedicated Volta implementation.
+
+#include "common.cuh"
+#include "fattn-common.cuh"
+#include "fattn-wmma-f16.cuh"
+
+#ifdef FP16_MMA_AVAILABLE
+#include 
+#endif // FP16_MMA_AVAILABLE
+
+// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
+template
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(nwarps*WARP_SIZE, 1)
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_ext_f16(
+        const char * __restrict__ Q,
+        const char * __restrict__ K,
+        const char * __restrict__ V,
+        const char * __restrict__ mask,
+        float      * __restrict__ dst,
+        float2     * __restrict__ dst_meta,
+        const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
+        const float logit_softcap,
+        const int ne00,
+        const int ne01,
+        const int ne02,
+        const int ne03,
+        const int ne10,
+        const int ne11,
+        const int ne12,
+        const int ne13,
+        const int ne31,
+        const int nb31,
+        const int nb01,
+        const int nb02,
+        const int nb03,
+        const int nb11,
+        const int nb12,
+        const int nb13,
+        const int nb21,
+        const int nb22,
+        const int nb23,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const int ne3) {
+#if defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
+    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+    const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
+    const int ip  =        blockIdx.x % parallel_blocks;  // Index in group of blocks running for the same column in parallel.
+
+    static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
+    static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
+    constexpr int frag_m = ncols == 8 ? 32 : 16;
+    constexpr int frag_n = ncols == 8 ?  8 : 16;
+    static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
+    typedef nvcuda::wmma::fragment frag_a_K;
+    typedef nvcuda::wmma::fragment frag_a_V;
+    typedef nvcuda::wmma::fragment frag_b;
+    typedef nvcuda::wmma::fragment                      frag_c_KQ;
+    typedef nvcuda::wmma::fragment                          frag_c_VKQ;
+
+    constexpr int KQ_stride_tc  = nwarps*frag_m; // Number of KQ rows calculated in parallel.
+    constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
+    static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");
+
+    // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
+    constexpr int D_padded = D + 8;
+    constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
+    constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
+
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+    const float * Q_f   = (const float *) (Q + nb02* blockIdx.y              + nb01*ic0);
+    const half  * K_h   = (const half  *) (K + nb12*(blockIdx.y / gqa_ratio));
+    const half  * V_h   = (const half  *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+    const half  * maskh = (const half  *)  mask + (nb31/sizeof(half))* ic0;
+    const half2 * mask2 = (const half2 *)  mask + (nb31/sizeof(half))*(ic0/2);
+
+    const int stride_Q  = nb01 / sizeof(float);
+    const int stride_KV = nb11 / sizeof(half);
+
+    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const half  slopeh = __float2half(slopef);
+    const half2 slope2 = make_half2(slopef, slopef);
+
+    const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap);
+
+    frag_b Q_b[D/16][ncols/frag_n];
+
+    // A single buffer for temporarily holding tiles of KQ and VKQ parts:
+    constexpr int mem_KQ = ncols*kqs_padded*kqar;
+    constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
+    __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
+    float * KQ_f = (float *) KQ;
+    half2 * KQ2 = (half2 *) KQ;
+
+    float    KQ_rowsum_f[ncols/nwarps] = {0.0f};
+    float       KQ_max_f[ncols/nwarps];
+    float KQ_max_scale_f[ncols/nwarps] = {0.0f};
+
+#pragma unroll
+    for (int j = 0; j < ncols/nwarps; ++j) {
+        KQ_max_f[j] = -FLT_MAX/2.0f;
+    }
+
+    half2    KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}};
+    half2       KQ_max_h2[ncols/nwarps];
+    half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}};
+
+#pragma unroll
+    for (int j = 0; j < ncols/nwarps; ++j) {
+        KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
+    }
+
+    __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
+    half2 * VKQ2 = (half2 *) VKQ;
+#pragma unroll
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+#pragma unroll
+        for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+            if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+                break;
+            }
+            VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
+        }
+    }
+
+    // Convert Q to half and apply scale, temporarily store in KQ:
+#pragma unroll
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+#pragma unroll
+        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+            if (i0 + WARP_SIZE > D && i >= D) {
+                break;
+            }
+            KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
+        }
+    }
+
+    __syncthreads();
+
+    // Load Q into tensor core fragments/registers since it will be used frequently:
+#pragma unroll
+    for (int i0 = 0; i0 < D; i0 += 16) {
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+            nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
+        }
+    }
+
+    __syncthreads();
+
+    // Iterate over ne11 == previous tokens:
+    for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
+        // Calculate tile of KQ:
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
+            frag_c_KQ KQ_c[ncols/frag_n];
+#pragma unroll
+            for (int j = 0; j < ncols/frag_n; ++j) {
+                nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
+            }
+#pragma unroll
+            for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
+                frag_a_K K_a;
+                nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+#pragma unroll
+                for (int j = 0; j < ncols/frag_n; ++j) {
+                    nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
+                }
+            }
+#pragma unroll
+            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+                nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
+            }
+        }
+
+        __syncthreads();
+
+        // Calculate softmax for each KQ column using the current max. value.
+        // The divisor is stored in KQ_rowsum and will be applied at the end.
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+            if (std::is_same::value) {
+                float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
+
+                    if (use_logit_softcap) {
+                        KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]);
+                    }
+                }
+
+                float KQ_max_new = KQ_max_f[j0/nwarps];
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
+                    KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
+                }
+                KQ_max_new = warp_reduce_max(KQ_max_new);
+
+                const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
+                KQ_max_scale_f[j0/nwarps] = expf(diff);
+                if (diff <= SOFTMAX_FTZ_THRESHOLD) {
+                    KQ_max_scale_f[j0/nwarps] = 0.0f;
+                }
+                KQ_max_f[j0/nwarps] = KQ_max_new;
+
+                float KQ_rowsum_add = 0.0f;
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
+                    KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
+                    if (diff <= SOFTMAX_FTZ_THRESHOLD) {
+                        KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
+                    }
+                    KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
+                    KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
+                }
+                KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
+
+                // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
+                KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
+            } else {
+                half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
+
+                    if (use_logit_softcap) {
+                        // There is no dedicated tangens hyperbolicus function for half2.
+                        KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
+                        KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))
+                                               /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
+
+                        KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2;
+                    }
+                }
+
+                half2 KQ_max_new = KQ_max_h2[j0/nwarps];
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
+                    KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
+                }
+                KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
+                const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
+                KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
+                const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
+                *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
+                KQ_max_h2[j0/nwarps] = KQ_max_new;
+
+                half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
+#pragma unroll
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+                    const int k = k0 + threadIdx.x;
+
+                    const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
+                    KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
+                    const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
+                    *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
+                    KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
+                    KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
+                }
+                KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
+
+                // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
+                KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
+            }
+        }
+
+        __syncthreads();
+
+        frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+#pragma unroll
+            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
+                const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
+                nvcuda::wmma::load_matrix_sync(
+                    KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
+                    KQ + j0*(kqar*kqs_padded) + k,
+                    kqar*kqs_padded);
+            }
+        }
+
+        frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
+#pragma unroll
+        for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
+#pragma unroll
+            for (int j = 0; j < ncols/frag_n; ++j) {
+                nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
+            }
+
+#pragma unroll
+            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
+                const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
+
+                frag_a_V v_a;
+                nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+#pragma unroll
+                for (int j = 0; j < ncols/frag_n; ++j) {
+                    nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
+                }
+            }
+        }
+
+        __syncthreads();
+
+        const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
+#pragma unroll
+            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+                nvcuda::wmma::store_matrix_sync(
+                    KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
+                    VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
+                    D_padded, nvcuda::wmma::mem_col_major);
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+            half2 VKQ_scale;
+            if (std::is_same::value) {
+                VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]);
+            } else {
+                VKQ_scale = KQ_max_scale_h2[j0/nwarps];
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+                if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+                    break;
+                }
+
+                half2 VKQ_add = make_half2(0.0f, 0.0f);
+#pragma unroll
+                for (int l = 0; l < VKQ_ratio; ++l) {
+                    VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
+                }
+                VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
+            }
+        }
+
+        __syncthreads();
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        const int j_VKQ = j0 + threadIdx.y;
+        if (ic0 + j_VKQ >= ne01) {
+            return;
+        }
+        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+
+        float KQ_rowsum_j;
+        if (std::is_same::value) {
+            KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];
+        } else {
+            KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
+        }
+
+#pragma unroll
+        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+            if (i0 + WARP_SIZE > D && i >= D) {
+                break;
+            }
+            float dst_val = VKQ[j_VKQ*D_padded + i];
+            if (parallel_blocks == 1) {
+                dst_val /= KQ_rowsum_j;
+            }
+            dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
+        }
+
+        if (parallel_blocks == 1 || threadIdx.x != 0) {
+            continue;
+        }
+
+        float2 dst_meta_val;
+        if (std::is_same::value) {
+            dst_meta_val.x = KQ_max_f[j0/nwarps];
+        } else {
+            dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
+        }
+        dst_meta_val.y = KQ_rowsum_j;
+        dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
+    }
+#else
+   NO_DEVICE_CODE;
+#endif // defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+}
+
+constexpr int get_max_power_of_2(int x) {
+    return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
+}
+
+static_assert(get_max_power_of_2(1) == 1, "Test failed.");
+static_assert(get_max_power_of_2(2) == 2, "Test failed.");
+static_assert(get_max_power_of_2(4) == 4, "Test failed.");
+static_assert(get_max_power_of_2(6) == 2, "Test failed.");
+
+// Number of VKQ rows calculated in parallel:
+constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
+    return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
+}
+
+static_assert(get_VKQ_stride(128, 1, 32) ==  32, "Test failed.");
+static_assert(get_VKQ_stride(128, 2, 32) ==  64, "Test failed.");
+static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
+static_assert(get_VKQ_stride( 64, 1, 32) ==  32, "Test failed.");
+static_assert(get_VKQ_stride( 64, 2, 32) ==  64, "Test failed.");
+static_assert(get_VKQ_stride( 64, 4, 32) ==  64, "Test failed.");
+static_assert(get_VKQ_stride( 80, 1, 16) ==  16, "Test failed.");
+static_assert(get_VKQ_stride( 80, 2, 16) ==  16, "Test failed.");
+static_assert(get_VKQ_stride( 80, 4, 16) ==  16, "Test failed.");
+
+template 
+void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
+
+    constexpr int nwarps = 4;
+
+    constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
+    const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
+    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
+    if (4*blocks_num_pb1 < 2*nsm) {
+        constexpr int parallel_blocks = 4;
+        fattn_kernel_t fattn_kernel;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            fattn_kernel = flash_attn_ext_f16<
+                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+        } else {
+            constexpr bool use_logit_softcap = true;
+            fattn_kernel = flash_attn_ext_f16<
+                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+        }
+        launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+        return;
+    }
+    if (2*blocks_num_pb1 < 2*nsm) {
+        constexpr int parallel_blocks = 2;
+        fattn_kernel_t fattn_kernel;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            fattn_kernel = flash_attn_ext_f16<
+                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+        } else {
+            constexpr bool use_logit_softcap = true;
+            fattn_kernel = flash_attn_ext_f16<
+                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+        }
+        launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+        return;
+    }
+    constexpr int parallel_blocks = 1;
+    fattn_kernel_t fattn_kernel;
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        fattn_kernel = flash_attn_ext_f16<
+            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+    } else {
+        constexpr bool use_logit_softcap = true;
+        fattn_kernel = flash_attn_ext_f16<
+            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+    }
+    launch_fattn(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+}
+
+void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
+
+    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
+
+    if (prec != GGML_PREC_DEFAULT) {
+        if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
+            constexpr int cols_per_block = 16;
+            switch (Q->ne[0]) {
+                case 64:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
+                    break;
+                case 80:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
+                    break;
+                case 96:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
+                    break;
+                case 112:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
+                    break;
+                case 128:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+                    break;
+                case 256:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
+                    break;
+                default:
+                    GGML_ABORT("fatal error");
+                    break;
+            }
+        } else {
+            constexpr int cols_per_block = 32;
+            switch (Q->ne[0]) {
+                case 64:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
+                    break;
+                case 80:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
+                    break;
+                case 96:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
+                    break;
+                case 112:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
+                    break;
+                case 128:
+                    ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+                    break;
+                // case 256:
+                //     ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
+                //     break;
+                default:
+                    GGML_ABORT("fatal error");
+                    break;
+            }
+        }
+        return;
+    }
+
+    if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
+        constexpr int cols_per_block = 8;
+        switch (Q->ne[0]) {
+            case 64:
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
+                break;
+            case 96:
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
+                break;
+            case 128:
+                ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
+                break;
+            case 256:
+                ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
+                break;
+            default:
+                GGML_ABORT("fatal error");
+                break;
+        }
+        return;
+    }
+
+    if (Q->ne[1] <= 32) {
+        constexpr int cols_per_block = 16;
+        switch (Q->ne[0]) {
+            case 64:
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
+                break;
+            case 80:
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
+                break;
+            case 96:
+                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
+                break;
+            case 112:
+                ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
+                break;
+            case 128:
+                ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
+                break;
+            case 256:
+                ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
+                break;
+            default:
+                GGML_ABORT("fatal error");
+                break;
+        }
+        return;
+    }
+
+    constexpr int cols_per_block = 32;
+    switch (Q->ne[0]) {
+        case 64:
+            ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
+            break;
+        case 80:
+            ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
+            break;
+        case 96:
+            ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
+            break;
+        case 112:
+            ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
+            break;
+        case 128:
+            ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
+            break;
+        case 256:
+            ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
+            break;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
+}
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
index 860d0e6d..beeea95e 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
@@ -1,543 +1,3 @@
 #include "common.cuh"
-#include "fattn-common.cuh"
 
-#ifdef FP16_MMA_AVAILABLE
-#include 
-#endif // FP16_MMA_AVAILABLE
-
-// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
-template
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-__launch_bounds__(nwarps*WARP_SIZE, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-static __global__ void flash_attn_ext_f16(
-        const char * __restrict__ Q,
-        const char * __restrict__ K,
-        const char * __restrict__ V,
-        const char * __restrict__ mask,
-        float      * __restrict__ dst,
-        float2     * __restrict__ dst_meta,
-        const float scale,
-        const float max_bias,
-        const float m0,
-        const float m1,
-        const uint32_t n_head_log2,
-        const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int nb31,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
-#ifdef FP16_MMA_AVAILABLE
-    // Skip unused kernel variants for faster compilation:
-    if (use_logit_softcap && !(D == 128 || D == 256)) {
-        NO_DEVICE_CODE;
-        return;
-    }
-
-    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
-
-    const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
-    const int ip  =        blockIdx.x % parallel_blocks;  // Index in group of blocks running for the same column in parallel.
-
-    static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
-    static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
-    constexpr int frag_m = ncols == 8 ? 32 : 16;
-    constexpr int frag_n = ncols == 8 ?  8 : 16;
-    static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
-    typedef nvcuda::wmma::fragment frag_a_K;
-    typedef nvcuda::wmma::fragment frag_a_V;
-    typedef nvcuda::wmma::fragment frag_b;
-    typedef nvcuda::wmma::fragment                      frag_c_KQ;
-    typedef nvcuda::wmma::fragment                          frag_c_VKQ;
-
-    constexpr int KQ_stride_tc  = nwarps*frag_m; // Number of KQ rows calculated in parallel.
-    constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
-    static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");
-
-    // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
-    constexpr int D_padded = D + 8;
-    constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
-    constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
-
-    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float * Q_f   = (const float *) (Q + nb02* blockIdx.y              + nb01*ic0);
-    const half  * K_h   = (const half  *) (K + nb12*(blockIdx.y / gqa_ratio));
-    const half  * V_h   = (const half  *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
-    const half  * maskh = (const half  *)  mask + (nb31/sizeof(half))* ic0;
-    const half2 * mask2 = (const half2 *)  mask + (nb31/sizeof(half))*(ic0/2);
-
-    const int stride_Q  = nb01 / sizeof(float);
-    const int stride_KV = nb11 / sizeof(half);
-
-    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
-    const half  slopeh = __float2half(slopef);
-    const half2 slope2 = make_half2(slopef, slopef);
-
-    const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap);
-
-    frag_b Q_b[D/16][ncols/frag_n];
-
-    // A single buffer for temporarily holding tiles of KQ and VKQ parts:
-    constexpr int mem_KQ = ncols*kqs_padded*kqar;
-    constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
-    __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
-    float * KQ_f = (float *) KQ;
-    half2 * KQ2 = (half2 *) KQ;
-
-    float    KQ_rowsum_f[ncols/nwarps] = {0.0f};
-    float       KQ_max_f[ncols/nwarps];
-    float KQ_max_scale_f[ncols/nwarps] = {0.0f};
-
-#pragma unroll
-    for (int j = 0; j < ncols/nwarps; ++j) {
-        KQ_max_f[j] = -FLT_MAX/2.0f;
-    }
-
-    half2    KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}};
-    half2       KQ_max_h2[ncols/nwarps];
-    half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}};
-
-#pragma unroll
-    for (int j = 0; j < ncols/nwarps; ++j) {
-        KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
-    }
-
-    __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
-    half2 * VKQ2 = (half2 *) VKQ;
-#pragma unroll
-    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-        const int j = j0 + threadIdx.y;
-#pragma unroll
-        for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
-            const int i = i0 + threadIdx.x;
-            if (i0 + WARP_SIZE > D/2 && i >= D/2) {
-                break;
-            }
-            VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
-        }
-    }
-
-    // Convert Q to half and apply scale, temporarily store in KQ:
-#pragma unroll
-    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-        const int j = j0 + threadIdx.y;
-#pragma unroll
-        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
-            const int i = i0 + threadIdx.x;
-            if (i0 + WARP_SIZE > D && i >= D) {
-                break;
-            }
-            KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
-        }
-    }
-
-    __syncthreads();
-
-    // Load Q into tensor core fragments/registers since it will be used frequently:
-#pragma unroll
-    for (int i0 = 0; i0 < D; i0 += 16) {
-#pragma unroll
-        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
-        }
-    }
-
-    __syncthreads();
-
-    // Iterate over ne11 == previous tokens:
-    for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
-        // Calculate tile of KQ:
-#pragma unroll
-        for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
-            frag_c_KQ KQ_c[ncols/frag_n];
-#pragma unroll
-            for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
-            }
-#pragma unroll
-            for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
-                frag_a_K K_a;
-                nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
-#pragma unroll
-                for (int j = 0; j < ncols/frag_n; ++j) {
-                    nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
-                }
-            }
-#pragma unroll
-            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-                nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
-            }
-        }
-
-        __syncthreads();
-
-        // Calculate softmax for each KQ column using the current max. value.
-        // The divisor is stored in KQ_rowsum and will be applied at the end.
-#pragma unroll
-        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-            const int j = j0 + threadIdx.y;
-
-            if (std::is_same::value) {
-                float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
-#pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
-                    const int k = k0 + threadIdx.x;
-
-                    KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
-
-                    if (use_logit_softcap) {
-                        KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]);
-                    }
-                }
-
-                float KQ_max_new = KQ_max_f[j0/nwarps];
-#pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
-                    const int k = k0 + threadIdx.x;
-
-                    KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
-                    KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
-                }
-                KQ_max_new = warp_reduce_max(KQ_max_new);
-
-                const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
-                KQ_max_scale_f[j0/nwarps] = expf(diff);
-                if (diff <= SOFTMAX_FTZ_THRESHOLD) {
-                    KQ_max_scale_f[j0/nwarps] = 0.0f;
-                }
-                KQ_max_f[j0/nwarps] = KQ_max_new;
-
-                float KQ_rowsum_add = 0.0f;
-#pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
-                    const int k = k0 + threadIdx.x;
-
-                    const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
-                    KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
-                    if (diff <= SOFTMAX_FTZ_THRESHOLD) {
-                        KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
-                    }
-                    KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
-                    KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
-                }
-                KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
-
-                // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
-                KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
-            } else {
-                half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
-#pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
-                    const int k = k0 + threadIdx.x;
-
-                    KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
-
-                    if (use_logit_softcap) {
-                        // There is no dedicated tangens hyperbolicus function for half2.
-                        KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
-                        KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))
-                                               /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
-
-                        KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2;
-                    }
-                }
-
-                half2 KQ_max_new = KQ_max_h2[j0/nwarps];
-#pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
-                    const int k = k0 + threadIdx.x;
-
-                    KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
-                    KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
-                }
-                KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
-                const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
-                KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
-                const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
-                *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
-                KQ_max_h2[j0/nwarps] = KQ_max_new;
-
-                half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
-#pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
-                    const int k = k0 + threadIdx.x;
-
-                    const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
-                    KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
-                    const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
-                    *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
-                    KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
-                    KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
-                }
-                KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
-
-                // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
-                KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
-            }
-        }
-
-        __syncthreads();
-
-        frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
-#pragma unroll
-        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-#pragma unroll
-            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
-                const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
-                nvcuda::wmma::load_matrix_sync(
-                    KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
-                    KQ + j0*(kqar*kqs_padded) + k,
-                    kqar*kqs_padded);
-            }
-        }
-
-        frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
-#pragma unroll
-        for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
-#pragma unroll
-            for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
-            }
-
-#pragma unroll
-            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
-                const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
-
-                frag_a_V v_a;
-                nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
-#pragma unroll
-                for (int j = 0; j < ncols/frag_n; ++j) {
-                    nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
-                }
-            }
-        }
-
-        __syncthreads();
-
-        const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
-#pragma unroll
-        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
-#pragma unroll
-            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-                nvcuda::wmma::store_matrix_sync(
-                    KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
-                    VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
-                    D_padded, nvcuda::wmma::mem_col_major);
-            }
-        }
-
-        __syncthreads();
-
-#pragma unroll
-        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-            const int j = j0 + threadIdx.y;
-
-            half2 VKQ_scale;
-            if (std::is_same::value) {
-                VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]);
-            } else {
-                VKQ_scale = KQ_max_scale_h2[j0/nwarps];
-            }
-
-#pragma unroll
-            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
-                const int i = i0 + threadIdx.x;
-                if (i0 + WARP_SIZE > D/2 && i >= D/2) {
-                    break;
-                }
-
-                half2 VKQ_add = make_half2(0.0f, 0.0f);
-#pragma unroll
-                for (int l = 0; l < VKQ_ratio; ++l) {
-                    VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
-                }
-                VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
-            }
-        }
-
-        __syncthreads();
-    }
-
-#pragma unroll
-    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-        const int j_VKQ = j0 + threadIdx.y;
-        if (ic0 + j_VKQ >= ne01) {
-            return;
-        }
-        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
-
-        float KQ_rowsum_j;
-        if (std::is_same::value) {
-            KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];
-        } else {
-            KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
-        }
-
-#pragma unroll
-        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
-            const int i = i0 + threadIdx.x;
-            if (i0 + WARP_SIZE > D && i >= D) {
-                break;
-            }
-            float dst_val = VKQ[j_VKQ*D_padded + i];
-            if (parallel_blocks == 1) {
-                dst_val /= KQ_rowsum_j;
-            }
-            dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
-        }
-
-        if (parallel_blocks == 1 || threadIdx.x != 0) {
-            continue;
-        }
-
-        float2 dst_meta_val;
-        if (std::is_same::value) {
-            dst_meta_val.x = KQ_max_f[j0/nwarps];
-        } else {
-            dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
-        }
-        dst_meta_val.y = KQ_rowsum_j;
-        dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
-    }
-#else
-   NO_DEVICE_CODE;
-#endif // FP16_MMA_AVAILABLE
-}
-
-constexpr int get_max_power_of_2(int x) {
-    return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
-}
-
-static_assert(get_max_power_of_2(1) == 1, "Test failed.");
-static_assert(get_max_power_of_2(2) == 2, "Test failed.");
-static_assert(get_max_power_of_2(4) == 4, "Test failed.");
-static_assert(get_max_power_of_2(6) == 2, "Test failed.");
-
-// Number of VKQ rows calculated in parallel:
-constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
-    return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
-}
-
-static_assert(get_VKQ_stride(128, 1, 32) ==  32, "Test failed.");
-static_assert(get_VKQ_stride(128, 2, 32) ==  64, "Test failed.");
-static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
-static_assert(get_VKQ_stride( 64, 1, 32) ==  32, "Test failed.");
-static_assert(get_VKQ_stride( 64, 2, 32) ==  64, "Test failed.");
-static_assert(get_VKQ_stride( 64, 4, 32) ==  64, "Test failed.");
-static_assert(get_VKQ_stride( 80, 1, 16) ==  16, "Test failed.");
-static_assert(get_VKQ_stride( 80, 2, 16) ==  16, "Test failed.");
-static_assert(get_VKQ_stride( 80, 4, 16) ==  16, "Test failed.");
-
-template 
-void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * KQV = dst;
-    const ggml_tensor * Q   = dst->src[0];
-
-    constexpr int nwarps = 4;
-
-    constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
-    const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
-    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
-
-    float logit_softcap;
-    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
-
-    if (4*blocks_num_pb1 < 2*nsm) {
-        constexpr int parallel_blocks = 4;
-        fattn_kernel_t fattn_kernel;
-        if (logit_softcap == 0.0f) {
-            constexpr bool use_logit_softcap = false;
-            fattn_kernel = flash_attn_ext_f16<
-                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-        } else {
-            constexpr bool use_logit_softcap = true;
-            fattn_kernel = flash_attn_ext_f16<
-                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-        }
-        launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
-        return;
-    }
-    if (2*blocks_num_pb1 < 2*nsm) {
-        constexpr int parallel_blocks = 2;
-        fattn_kernel_t fattn_kernel;
-        if (logit_softcap == 0.0f) {
-            constexpr bool use_logit_softcap = false;
-            fattn_kernel = flash_attn_ext_f16<
-                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-        } else {
-            constexpr bool use_logit_softcap = true;
-            fattn_kernel = flash_attn_ext_f16<
-                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-        }
-        launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
-        return;
-    }
-    constexpr int parallel_blocks = 1;
-    fattn_kernel_t fattn_kernel;
-    if (logit_softcap == 0.0f) {
-        constexpr bool use_logit_softcap = false;
-        fattn_kernel = flash_attn_ext_f16<
-            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-    } else {
-        constexpr bool use_logit_softcap = true;
-        fattn_kernel = flash_attn_ext_f16<
-            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-    }
-    launch_fattn(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
-}
-
-#define DECL_FATTN_WMMA_F16_CASE(D, cols_per_block, KQ_acc_t)                         \
-    template void ggml_cuda_flash_attn_ext_wmma_f16_case                              \
-    (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
-
-extern DECL_FATTN_WMMA_F16_CASE( 64, 16, float);
-extern DECL_FATTN_WMMA_F16_CASE( 80, 16, float);
-extern DECL_FATTN_WMMA_F16_CASE( 96, 16, float);
-extern DECL_FATTN_WMMA_F16_CASE(112, 16, float);
-extern DECL_FATTN_WMMA_F16_CASE(128, 16, float);
-extern DECL_FATTN_WMMA_F16_CASE(256, 16, float);
-
-extern DECL_FATTN_WMMA_F16_CASE( 64, 32, float);
-extern DECL_FATTN_WMMA_F16_CASE( 80, 32, float);
-extern DECL_FATTN_WMMA_F16_CASE( 96, 32, float);
-extern DECL_FATTN_WMMA_F16_CASE(112, 32, float);
-extern DECL_FATTN_WMMA_F16_CASE(128, 32, float);
-// extern DECL_FATTN_WMMA_F16_CASE(256, 16, float);
-
-extern DECL_FATTN_WMMA_F16_CASE( 64,  8, half);
-extern DECL_FATTN_WMMA_F16_CASE( 96,  8, half);
-extern DECL_FATTN_WMMA_F16_CASE(128,  8, half);
-extern DECL_FATTN_WMMA_F16_CASE(256,  8, half);
-
-extern DECL_FATTN_WMMA_F16_CASE( 64, 16, half);
-extern DECL_FATTN_WMMA_F16_CASE( 80, 16, half);
-extern DECL_FATTN_WMMA_F16_CASE( 96, 16, half);
-extern DECL_FATTN_WMMA_F16_CASE(112, 16, half);
-extern DECL_FATTN_WMMA_F16_CASE(128, 16, half);
-extern DECL_FATTN_WMMA_F16_CASE(256, 16, half);
-
-extern DECL_FATTN_WMMA_F16_CASE( 64, 32, half);
-extern DECL_FATTN_WMMA_F16_CASE( 80, 32, half);
-extern DECL_FATTN_WMMA_F16_CASE( 96, 32, half);
-extern DECL_FATTN_WMMA_F16_CASE(112, 32, half);
-extern DECL_FATTN_WMMA_F16_CASE(128, 32, half);
-extern DECL_FATTN_WMMA_F16_CASE(256, 16, half);
+void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu b/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu
index 0b26b0f8..b1becccb 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu
@@ -1,5 +1,6 @@
 #include "common.cuh"
 #include "fattn-common.cuh"
+#include "fattn-mma-f16.cuh"
 #include "fattn-tile-f16.cuh"
 #include "fattn-tile-f32.cuh"
 #include "fattn-vec-f16.cuh"
@@ -7,144 +8,89 @@
 #include "fattn-wmma-f16.cuh"
 #include "fattn.cuh"
 
-#include 
+template 
+static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
 
-static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * KQV = dst;
-    const ggml_tensor * Q   = dst->src[0];
-
-    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
-
-    if (prec != GGML_PREC_DEFAULT) {
-        if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
-            constexpr int cols_per_block = 16;
-            switch (Q->ne[0]) {
-                case 64:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
-                    break;
-                case 80:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
-                    break;
-                case 96:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
-                    break;
-                case 112:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
-                    break;
-                case 128:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
-                    break;
-                case 256:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
-                    break;
-                default:
-                    GGML_ABORT("fatal error");
-                    break;
-            }
-        } else {
-            constexpr int cols_per_block = 32;
-            switch (Q->ne[0]) {
-                case 64:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
-                    break;
-                case 80:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
-                    break;
-                case 96:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
-                    break;
-                case 112:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
-                    break;
-                case 128:
-                    ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
-                    break;
-                // case 256:
-                //     ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
-                //     break;
-                default:
-                    GGML_ABORT("fatal error");
-                    break;
-            }
-        }
+    if (Q->ne[1] <= 8/ncols2) {
+        ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst);
         return;
     }
 
-    if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
-        constexpr int cols_per_block = 8;
-        switch (Q->ne[0]) {
-            case 64:
-                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
-                break;
-            case 96:
-                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
-                break;
-            case 128:
-                ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
-                break;
-            case 256:
-                ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
-                break;
-            default:
-                GGML_ABORT("fatal error");
-                break;
-        }
+    if (Q->ne[1] <= 16/ncols2) {
+        ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst);
         return;
     }
 
-    if (Q->ne[1] <= 32) {
-        constexpr int cols_per_block = 16;
-        switch (Q->ne[0]) {
-            case 64:
-                ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
-                break;
-            case 80:
-                ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
-                break;
-            case 96:
-                ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
-                break;
-            case 112:
-                ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
-                break;
-            case 128:
-                ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
-                break;
-            case 256:
-                ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
-                break;
-            default:
-                GGML_ABORT("fatal error");
-                break;
-        }
+    if (Q->ne[1] <= 32/ncols2) {
+        ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst);
         return;
     }
 
-    constexpr int cols_per_block = 32;
+    ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst);
+}
+
+template 
+static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+
     switch (Q->ne[0]) {
         case 64:
-            ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 64, ncols2>(ctx, dst);
             break;
         case 80:
-            ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 80, ncols2>(ctx, dst);
             break;
         case 96:
-            ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 96, ncols2>(ctx, dst);
             break;
         case 112:
-            ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<112, ncols2>(ctx, dst);
             break;
         case 128:
-            ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst);
             break;
         case 256:
-            ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst);
             break;
         default:
             GGML_ABORT("fatal error");
             break;
     }
 }
+
+static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV  = dst;
+    const ggml_tensor * Q    = dst->src[0];
+    const ggml_tensor * K    = dst->src[1];
+    const ggml_tensor * mask = dst->src[3];
+
+    float max_bias = 0.0f;
+    memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
+
+    const float use_gqa_opt = mask && max_bias == 0.0f;
+
+    GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
+    const int gqa_ratio = Q->ne[2] / K->ne[2];
+
+    if (use_gqa_opt && gqa_ratio % 8 == 0) {
+        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst);
+        return;
+    }
+
+    if (use_gqa_opt && gqa_ratio == 4) {
+        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<4>(ctx, dst);
+        return;
+    }
+
+    if (use_gqa_opt && gqa_ratio == 2) {
+        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<2>(ctx, dst);
+        return;
+    }
+
+    ggml_cuda_flash_attn_ext_mma_f16_switch_hs<1>(ctx, dst);
+}
+
 #define FATTN_VEC_F16_CASE(D, type_K, type_V)                               \
     if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) {    \
         ggml_cuda_flash_attn_ext_vec_f16_case(ctx, dst); \
@@ -296,8 +242,11 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
 }
 
 void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * KQV = dst;
-    const ggml_tensor * Q   = dst->src[0];
+    const ggml_tensor * KQV  = dst;
+    const ggml_tensor * Q    = dst->src[0];
+    const ggml_tensor * K    = dst->src[1];
+    const ggml_tensor * V    = dst->src[2];
+    const ggml_tensor * mask = dst->src[3];
 
     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
@@ -323,15 +272,26 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     }
 
     if (!fp16_mma_available(cc)) {
-        if (Q->ne[1] <= 8) {
-            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+        if (prec == GGML_PREC_DEFAULT) {
+            if (Q->ne[1] <= 8) {
+                ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+            } else {
+                ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
+            }
         } else {
-            ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
+            if (Q->ne[1] <= 8) {
+                ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+            } else {
+                ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
+            }
         }
         return;
     }
 
-    if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
+    const int gqa_ratio = Q->ne[2] / K->ne[2];
+    const bool mma_fast_for_bs1 = fp16_mma_available(cc) && gqa_ratio % 2 == 0 &&
+        K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && mask;
+    if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0 && !mma_fast_for_bs1) {
         if (prec == GGML_PREC_DEFAULT) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
             return;
@@ -341,5 +301,11 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         }
     }
 
-    ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
+    // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
+    if (cc == GGML_CUDA_CC_VOLTA) {
+        ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
+        return;
+    }
+
+    ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
 }
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/getrows.cu b/ml/backend/ggml/ggml/src/ggml-cuda/getrows.cu
index 4c370323..4cef53a9 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/getrows.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/getrows.cu
@@ -3,15 +3,15 @@
 
 template
 static __global__ void k_get_rows(
-            const void * src0, const int32_t * src1, dst_t * dst,
-            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
-            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
-            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
-            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
-            size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
+        const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
+        const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
+        /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/
+        /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
+        /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
+        const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
 
     const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
-    const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
+    const int i10 =  blockDim.y*blockIdx.y + threadIdx.y;
     const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
     const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
 
@@ -22,10 +22,10 @@ static __global__ void k_get_rows(
     const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
 
     dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
-    const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
+    const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03;
 
-    const int ib = i00/qk; // block index
-    const int iqs = (i00%qk)/qr; // quant index
+    const int ib   =  i00/qk;      // block index
+    const int iqs  = (i00%qk)/qr;  // quant index
     const int iybs = i00 - i00%qk; // dst block start index
     const int y_offset = qr == 1 ? 1 : qk/2;
 
@@ -39,15 +39,15 @@ static __global__ void k_get_rows(
 
 template
 static __global__ void k_get_rows_float(
-            const src0_t * src0, const int32_t * src1, dst_t * dst,
-            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
-            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
-            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
-            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
-            size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
+        const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
+        const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
+        /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/
+        /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
+        /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
+        const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
 
-    const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
-    const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
+    const int i00 =  blockIdx.x*blockDim.x + threadIdx.x;
+    const int i10 =  blockDim.y*blockIdx.y + threadIdx.y;
     const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
     const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
 
@@ -58,14 +58,38 @@ static __global__ void k_get_rows_float(
     const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
 
     dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
-    const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
+    const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
 
     dst_row[i00] = src0_row[i00];
 }
 
+template
+static __global__ void k_get_rows_back_float(
+        const grad_t * __restrict__ grad, const int32_t * __restrict__ rows, dst_t * __restrict__ dst, const int64_t ncols, const int64_t nrows_grad) {
+    const int col = blockIdx.x*blockDim.x + threadIdx.x;
+
+    if (col >= ncols) {
+        return;
+    }
+
+    const int dst_row = blockIdx.y*blockDim.y + threadIdx.y;
+
+    float sum = 0.0f;
+
+    for (int64_t i = 0; i < nrows_grad; ++i) {
+        if (rows[i] != dst_row) {
+            continue;
+        }
+        sum += grad[i*ncols + col];
+    }
+
+    dst[dst_row*ncols + col] = sum;
+}
+
 template
-static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-                            const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
+static void get_rows_cuda(
+        const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+        const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
@@ -87,22 +111,25 @@ static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, gg
     GGML_ASSERT(ne00 % 2 == 0);
 
     k_get_rows<<>>(
-            src0_dd, src1_dd, dst_dd,
-            ne00, /*ne01, ne02, ne03,*/
-            /*ne10, ne11,*/ ne12, /*ne13,*/
-            /* s0,*/ s1, s2, s3,
-            /* nb00,*/ nb01, nb02, nb03,
-            s10, s11, s12/*, s13*/);
+        src0_dd, src1_dd, dst_dd,
+        ne00, /*ne01, ne02, ne03,*/
+        /*ne10, ne11,*/ ne12, /*ne13,*/
+        /* s0,*/ s1, s2, s3,
+        /* nb00,*/ nb01, nb02, nb03,
+        s10, s11, s12/*, s13*/);
 
     GGML_UNUSED(dst);
 }
 
 template
-static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
-                                const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
+static void get_rows_cuda_float(
+        const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+        const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
+    GGML_ASSERT(ne13 == 1);
+
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
     const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
     const dim3 block_nums(block_num_x, ne10, ne11*ne12);
@@ -119,12 +146,12 @@ static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * sr
     //const size_t s13 = nb13 / ggml_element_size(src1);
 
     k_get_rows_float<<>>(
-            src0_dd, src1_dd, dst_dd,
-            ne00, /*ne01, ne02, ne03,*/
-            /*ne10, ne11,*/ ne12, /*ne13,*/
-            /* s0,*/ s1, s2, s3,
-            /* nb00,*/ nb01, nb02, nb03,
-            s10, s11, s12/*, s13*/);
+        src0_dd, src1_dd, dst_dd,
+        ne00, /*ne01, ne02, ne03,*/
+        /*ne10, ne11,*/ ne12, /*ne13,*/
+        /* s0,*/ s1, s2, s3,
+        /* nb00,*/ nb01, nb02, nb03,
+        s10, s11, s12/*, s13*/);
 
     GGML_UNUSED(dst);
 }
@@ -132,42 +159,41 @@ static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * sr
 void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
-    const float * src0_d = (const float *)src0->data;
-    const float * src1_d = (const float *)src1->data;
-    float * dst_d = (float *)dst->data;
+
+    const void    * src0_d = (const void    *) src0->data;
+    const int32_t * src1_d = (const int32_t *) src1->data;
+    float         * dst_d  = (float         *) dst->data;
+
     cudaStream_t stream = ctx.stream();
 
-
     GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
 
     GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
     GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
-    GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
-
-    const int32_t * src1_i32 = (const int32_t *) src1_d;
+    GGML_ASSERT(dst->nb[0]  == ggml_type_size(dst->type));
 
     switch (src0->type) {
         case GGML_TYPE_F16:
-            get_rows_cuda_float(src0, src1, dst, (const half *)src0_d, src1_i32, dst_d, stream);
+            get_rows_cuda_float(src0, src1, dst, (const half *) src0_d, src1_d, dst_d, stream);
             break;
         case GGML_TYPE_F32:
-            get_rows_cuda_float(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            get_rows_cuda_float(src0, src1, dst, (const float *) src0_d, src1_d, dst_d, stream);
             break;
         case GGML_TYPE_Q4_0:
-            get_rows_cuda(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            get_rows_cuda(src0, src1, dst, src0_d, src1_d, dst_d, stream);
             break;
         case GGML_TYPE_Q4_1:
-            get_rows_cuda(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            get_rows_cuda(src0, src1, dst, src0_d, src1_d, dst_d, stream);
             break;
         case GGML_TYPE_Q5_0:
-            get_rows_cuda(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            get_rows_cuda(src0, src1, dst, src0_d, src1_d, dst_d, stream);
             break;
         case GGML_TYPE_Q5_1:
-            get_rows_cuda(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            get_rows_cuda(src0, src1, dst, src0_d, src1_d, dst_d, stream);
             break;
         case GGML_TYPE_Q8_0:
-            get_rows_cuda(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            get_rows_cuda(src0, src1, dst, src0_d, src1_d, dst_d, stream);
             break;
         default:
             // TODO: k-quants
@@ -175,3 +201,34 @@ void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
             break;
     }
 }
+
+void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
+    const ggml_tensor * src1 = dst->src[1]; // src1 in forward pass
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const float   * src0_d = (const float   *) src0->data;
+    const int32_t * src1_d = (const int32_t *) src1->data;
+    float         * dst_d  = (float         *) dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    GGML_ASSERT(ne02*ne03 == 1);
+    GGML_ASSERT(ne12*ne13 == 1);
+    GGML_ASSERT(ne2*ne3 == 1);
+
+    const dim3 block_dims(CUDA_GET_ROWS_BACK_BLOCK_SIZE, 1, 1);
+    const int block_num_x = (ne00 + CUDA_GET_ROWS_BACK_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BACK_BLOCK_SIZE;
+    const dim3 block_nums(block_num_x, ne1, 1);
+
+    k_get_rows_back_float<<>>(src0_d, src1_d, dst_d, ne00, ne10);
+}
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/getrows.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/getrows.cuh
index bbf13023..a1ca643f 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/getrows.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/getrows.cuh
@@ -1,5 +1,8 @@
 #include "common.cuh"
 
 #define CUDA_GET_ROWS_BLOCK_SIZE 256
+#define CUDA_GET_ROWS_BACK_BLOCK_SIZE 256
 
 void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu b/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu
index 9286f866..1adf08fa 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -37,10 +37,13 @@
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv6.cuh"
+#include "ggml-cuda/gla.cuh"
+#include "ggml.h"
 
 #include 
 #include 
 #include 
+#include 
 #include 
 #include 
 #include 
@@ -61,7 +64,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
 [[noreturn]]
 void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) {
     int id = -1; // in case cudaGetDevice fails
-    cudaGetDevice(&id);
+    (void)cudaGetDevice(&id);
 
     GGML_LOG_ERROR(GGML_CUDA_NAME " error: %s\n", msg);
     GGML_LOG_ERROR("  current device: %d, in function %s at %s:%d\n", id, func, file, line);
@@ -118,12 +121,78 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
 #endif
 }
 
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+static int ggml_cuda_parse_id(char devName[]) {
+    // A list of possible Target IDs can be found under the rocclr/clr repo in device.cpp
+    // these values are not stable so this is susceptible to breakage
+    // https://github.com/ROCm/clr/blob/amd-staging/rocclr/device/device.cpp
+    int archMajor = 0x0;
+    int archMinor = 0x0;
+    int archNum = GGML_CUDA_CC_OFFSET_AMD;
+    int archLen = strlen(devName);
+    char archName[archLen + 1];
+
+    // strip leading 'gfx' while copying into our buffer
+    if (archLen > 3) {
+        strcpy(archName, &devName[3]);
+        archLen -= 3;
+    }
+
+    // trim trailing :xnack- or :sramecc- statuses
+    archLen = strcspn(archName, ":");
+    archName[archLen] = '\0';
+
+    // tease out the version information
+    if (archLen > 8) {
+        // versions labeled generic use '-' as delimiter
+        // strip the trailing "-generic" then iterate through what remains
+        if ((strstr(archName, "-generic"))) {
+            archName[archLen - 8] = '\0';
+            char * pch;
+            if ((pch = strtok(archName, "-"))) {
+                archMajor = (int)strtoul(pch, 0, 16);
+                if ((pch = strtok(NULL, "-"))) {
+                    archMinor = 0x10 * (int)strtoul(pch, 0, 16);
+                }
+            }
+        }
+    } else if (archLen >= 3) {
+        // last two digits should be the minor * 0x10 + stepping
+        archMinor = (int)strtoul(&archName[archLen - 2], 0, 16);
+        archName[archLen - 2] = '\0';
+
+        // only the major version remains
+        archMajor = (int)strtoul(archName, 0, 16);
+    }
+    archNum += archMajor * 0x100;
+    archNum += archMinor;
+    return archNum;
+}
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
+
 static ggml_cuda_device_info ggml_cuda_init() {
 #ifdef __HIP_PLATFORM_AMD__
     // Workaround for a rocBLAS bug when using multiple graphics cards:
     // https://github.com/ROCmSoftwarePlatform/rocBLAS/issues/1346
-    rocblas_initialize();
-    CUDA_CHECK(cudaDeviceSynchronize());
+    {
+        int major_version = 0;
+        size_t version_length = 0;
+        if (rocblas_get_version_string_size(&version_length) == rocblas_status_success) {
+            std::vector version(version_length+1, '\0');
+            if (rocblas_get_version_string(version.data(), version.size()) == rocblas_status_success) {
+                version.resize(::strlen(version.data()));
+                int parsed_value = 0;
+                if (std::from_chars(version.data(), version.data() + version.size(), parsed_value).ec == std::errc()) {
+                    major_version = parsed_value;
+                }
+            }
+        }
+        if (major_version < 4) {
+            GGML_LOG_DEBUG(GGML_CUDA_NAME " calling rocblas_initialize as a workaround for a rocBLAS bug\n");
+            rocblas_initialize();
+            CUDA_CHECK(cudaDeviceSynchronize());
+        }
+    }
 #endif
 
     ggml_cuda_device_info info = {};
@@ -151,7 +220,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
     for (int id = 0; id < info.device_count; ++id) {
         int device_vmm = 0;
 
-#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
+#if defined(GGML_USE_VMM)
         CUdevice device;
         CU_CHECK(cuDeviceGet(&device, id));
         CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));
@@ -163,24 +232,46 @@ static ggml_cuda_device_info ggml_cuda_init() {
             alloc_prop.location.id = id;
             CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
         }
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
+#endif // defined(GGML_USE_VMM)
         info.devices[id].vmm = !!device_vmm;
 
         cudaDeviceProp prop;
         CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
-        GGML_LOG_INFO("  Device %d: %s, compute capability %d.%d, VMM: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
 
         info.default_tensor_split[id] = total_vram;
         total_vram += prop.totalGlobalMem;
 
-        info.devices[id].nsm   = prop.multiProcessorCount;
-        info.devices[id].smpb  = prop.sharedMemPerBlock;
+        info.devices[id].nsm       = prop.multiProcessorCount;
+        info.devices[id].smpb      = prop.sharedMemPerBlock;
+        info.devices[id].warp_size = prop.warpSize;
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
         info.devices[id].smpbo = prop.sharedMemPerBlock;
-        info.devices[id].cc = 100*prop.major + 10*prop.minor + GGML_CUDA_CC_OFFSET_AMD;
+
+        info.devices[id].cc = ggml_cuda_parse_id(prop.gcnArchName);
+        if ((info.devices[id].cc & 0xff00) == 0x0) {
+            GGML_LOG_WARN("invalid architecture ID received for device %d %s: %s  cc %d.%d\n",
+                            id, prop.name, prop.gcnArchName, prop.major, prop.minor);
+
+            // Fallback to prop.major and prop.minor
+            if (prop.major > 0) {
+                info.devices[id].cc = GGML_CUDA_CC_OFFSET_AMD + prop.major * 0x100;
+                info.devices[id].cc += prop.minor * 0x10;
+            }
+        }
+        GGML_LOG_INFO("  Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d\n",
+                      id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff,
+                      device_vmm ? "yes" : "no", prop.warpSize);
+#elif defined(GGML_USE_MUSA)
+        // TODO: refine the .cc to reflect MUSA's actual CC capabilities
+        info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
+        info.devices[id].cc = 100*prop.major + 10*prop.minor;
+        GGML_LOG_INFO("  Device %d: %s, compute capability %d.%d, VMM: %s\n",
+                        id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
 #else
         info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
         info.devices[id].cc = 100*prop.major + 10*prop.minor;
+        GGML_LOG_INFO("  Device %d: %s, compute capability %d.%d, VMM: %s\n",
+                        id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
     }
 
@@ -299,7 +390,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
 };
 
 // pool with virtual memory
-#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
+#if defined(GGML_USE_VMM)
 struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
     static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
 
@@ -308,6 +399,9 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
     size_t pool_used = 0;
     size_t pool_size = 0;
     size_t granularity;
+#if defined(GGML_USE_HIP)
+    std::vector> mappings;
+#endif
 
     explicit ggml_cuda_pool_vmm(int device) :
         device(device),
@@ -316,7 +410,14 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
 
     ~ggml_cuda_pool_vmm() {
         if (pool_addr != 0) {
+#if defined(GGML_USE_HIP)
+            // Workaround for https://github.com/ROCm/ROCR-Runtime/issues/285
+            for (std::pair & mapping : mappings) {
+                CU_CHECK(cuMemUnmap(mapping.first, mapping.second));
+            }
+#else
             CU_CHECK(cuMemUnmap(pool_addr, pool_size));
+#endif
             CU_CHECK(cuMemAddressFree(pool_addr, CUDA_POOL_VMM_MAX_SIZE));
         }
     }
@@ -349,7 +450,11 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
             }
 
             // map at the end of the pool
-            CU_CHECK(cuMemMap(pool_addr + pool_size, reserve_size, 0, handle, 0));
+            CUdeviceptr start_ptr = (CUdeviceptr)((char *)(pool_addr) + pool_size);
+            CU_CHECK(cuMemMap(start_ptr, reserve_size, 0, handle, 0));
+#if defined(GGML_USE_HIP)
+            mappings.push_back({start_ptr, reserve_size});
+#endif
 
             // the memory allocation handle is no longer needed after mapping
             CU_CHECK(cuMemRelease(handle));
@@ -359,7 +464,7 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
             access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
             access.location.id = device;
             access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
-            CU_CHECK(cuMemSetAccess(pool_addr + pool_size, reserve_size, &access, 1));
+            CU_CHECK(cuMemSetAccess((CUdeviceptr)((char *)(pool_addr) + pool_size), reserve_size, &access, 1));
 
             // add to the pool
             pool_size += reserve_size;
@@ -371,7 +476,7 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
 
         GGML_ASSERT(pool_addr != 0);
 
-        void * ptr = (void *) (pool_addr + pool_used);
+        void * ptr = (void *) ((CUdeviceptr)((char *)(pool_addr) + pool_used));
         *actual_size = size;
         pool_used += size;
 
@@ -390,17 +495,17 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
         pool_used -= size;
 
         // all deallocations must be in reverse order of the allocations
-        GGML_ASSERT(ptr == (void *) (pool_addr + pool_used));
+        GGML_ASSERT(ptr == (void *) ((char *)(pool_addr) + pool_used));
     }
 };
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
+#endif // defined(GGML_USE_VMM)
 
 std::unique_ptr ggml_backend_cuda_context::new_pool_for_device(int device) {
-#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
+#if defined(GGML_USE_VMM)
     if (ggml_cuda_info().devices[device].vmm) {
         return std::unique_ptr(new ggml_cuda_pool_vmm(device));
     }
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
+#endif // defined(GGML_USE_VMM)
     return std::unique_ptr(new ggml_cuda_pool_leg(device));
 }
 
@@ -547,7 +652,7 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac
     cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device);
     if (err != cudaSuccess) {
         // clear the error
-        cudaGetLastError();
+        (void)cudaGetLastError();
         GGML_LOG_ERROR("%s: allocating %.2f MiB on device %d: cudaMalloc failed: %s\n", __func__, size / 1024.0 / 1024.0, buft_ctx->device, cudaGetErrorString(err));
         return nullptr;
     }
@@ -962,7 +1067,7 @@ static void * ggml_cuda_host_malloc(size_t size) {
     cudaError_t err = cudaMallocHost((void **) &ptr, size);
     if (err != cudaSuccess) {
         // clear the error
-        cudaGetLastError();
+        (void)cudaGetLastError();
         GGML_LOG_DEBUG("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__,
                            size / 1024.0 / 1024.0, cudaGetErrorString(err));
         return nullptr;
@@ -1082,7 +1187,9 @@ static void ggml_cuda_op_mul_mat_cublas(
 
     const int compute_capability = ggml_cuda_info().devices[id].cc;
 
-    if (compute_capability >= GGML_CUDA_CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
+    const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
+
+    if (compute_capability >= GGML_CUDA_CC_VOLTA && use_fp16) {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
         ggml_cuda_pool_alloc src0_as_f16(ctx.pool(id));
         if (src0->type != GGML_TYPE_F16) {
@@ -1103,28 +1210,38 @@ static void ggml_cuda_op_mul_mat_cublas(
             to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
         }
         const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
-        ggml_cuda_pool_alloc dst_f16(ctx.pool(id), row_diff*src1_ncols);
-
-        const half alpha_f16 = 1.0f;
-        const half beta_f16 = 0.0f;
-
-        cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
-        if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
-            cu_compute_type = CUBLAS_COMPUTE_32F;
-        }
 
         CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
-        CUBLAS_CHECK(
-            cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
-                    row_diff, src1_ncols, ne10,
-                    &alpha_f16, src0_ptr,       CUDA_R_16F, ne00,
-                                src1_ptr,       CUDA_R_16F, ne10,
-                    &beta_f16,   dst_f16.get(), CUDA_R_16F, ldc,
-                    cu_compute_type,
-                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 
-        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-        to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+        if (GGML_CUDA_CC_IS_CDNA(compute_capability)) {
+            const float alpha = 1.0f;
+            const float beta = 0.0f;
+            CUBLAS_CHECK(
+                cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
+                        row_diff, src1_ncols, ne10,
+                        &alpha, src0_ptr,  CUDA_R_16F, ne00,
+                                src1_ptr,  CUDA_R_16F, ne10,
+                        &beta,   dst_dd_i, CUDA_R_32F, ldc,
+                        CUBLAS_COMPUTE_32F,
+                        CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+        } else {
+            ggml_cuda_pool_alloc dst_f16(ctx.pool(id), row_diff*src1_ncols);
+
+            const half alpha_f16 = 1.0f;
+            const half beta_f16 = 0.0f;
+
+            CUBLAS_CHECK(
+                cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
+                        row_diff, src1_ncols, ne10,
+                        &alpha_f16, src0_ptr,      CUDA_R_16F, ne00,
+                                    src1_ptr,      CUDA_R_16F, ne10,
+                        &beta_f16,  dst_f16.get(), CUDA_R_16F, ldc,
+                        CUBLAS_COMPUTE_16F,
+                        CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+            to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+        }
     } else {
         ggml_cuda_pool_alloc src0_ddq_as_f32(ctx.pool(id));
         ggml_cuda_pool_alloc src1_ddq_as_f32(ctx.pool(id));
@@ -1197,7 +1314,7 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
                         CUDA_CHECK(err);
                     } else {
                         // reset the error
-                        cudaGetLastError();
+                        (void)cudaGetLastError();
                     }
                 } else {
                     cudaError_t err = cudaDeviceDisablePeerAccess(id_other);
@@ -1205,7 +1322,7 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
                         CUDA_CHECK(err);
                     } else {
                         // reset the error
-                        cudaGetLastError();
+                        (void)cudaGetLastError();
                     }
                 }
             }
@@ -1256,8 +1373,6 @@ static void ggml_cuda_op_mul_mat(
     const int64_t ne13 = src1->ne[3];
     const int64_t nrows1 = ggml_nrows(src1);
 
-    GGML_ASSERT(ne03 == ne13);
-
     const int64_t ne0 = dst->ne[0];
     const int64_t ne1 = dst->ne[1];
 
@@ -1271,9 +1386,11 @@ static void ggml_cuda_op_mul_mat(
 
     GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
 
-    GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
+    GGML_ASSERT(ne12 % ne02 == 0);
+    GGML_ASSERT(ne13 % ne03 == 0);
 
     const int64_t i02_divisor = ne12 / ne02;
+    const int64_t i03_divisor = ne13 / ne03;
 
     const size_t src0_ts = ggml_type_size(src0->type);
     const size_t src0_bs = ggml_blck_size(src0->type);
@@ -1289,6 +1406,7 @@ static void ggml_cuda_op_mul_mat(
     GGML_ASSERT(!(split && ne02 > 1));
     GGML_ASSERT(!(split && ne03 > 1));
     GGML_ASSERT(!(split && ne02 < ne12));
+    GGML_ASSERT(!(split && ne03 < ne13));
 
     ggml_tensor_extra_gpu * src0_extra = split ? (ggml_tensor_extra_gpu *) src0->extra : nullptr;
 
@@ -1369,12 +1487,7 @@ static void ggml_cuda_op_mul_mat(
             const size_t nbytes_data    = ggml_nbytes(src0);
             const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
             dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), nbytes_data + nbytes_padding);
-        // TODO: remove this for MUSA once the Guilty Lockup issue is resolved
-#ifndef GGML_USE_MUSA
             CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd, 0, nbytes_data + nbytes_padding, stream));
-#else // GGML_USE_MUSA
-            CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data, 0, nbytes_padding, stream));
-#endif // !GGML_USE_MUSA
         }
 
         // If src0 is on a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared:
@@ -1452,7 +1565,8 @@ static void ggml_cuda_op_mul_mat(
                 }
 
                 // for split tensors the data begins at i0 == i0_offset_low
-                char  *  src0_dd_i =  dev[id].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
+                const size_t nbytes_src0_matrix = ne01*ne00*src0_ts / src0_bs;
+                char  *  src0_dd_i =  dev[id].src0_dd + ((i03/i03_divisor)*ne02 + (i02/i02_divisor)) * nbytes_src0_matrix;
                 float * src1_ddf_i = dev[id].src1_ddf + (i0*ne11 + src1_col_0) * ne10;
                 char  * src1_ddq_i = dev[id].src1_ddq +  src1_ddq_i_offset;
                 float *   dst_dd_i =   dev[id].dst_dd + (i0*ne1  + src1_col_0) * (dst_on_device ? ne0 : row_diff);
@@ -1496,8 +1610,9 @@ static void ggml_cuda_op_mul_mat(
                     CUDA_CHECK(cudaGetLastError());
                 }
 
-                if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) {
-                    CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream));
+                if (src1_col_0 == 0 && !src0_is_contiguous && i03 % i03_divisor == 0 && i02 % i02_divisor == 0) {
+                    CUDA_CHECK(ggml_cuda_cpy_tensor_2d(
+                        src0_dd_i, src0, i03/i03_divisor, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream));
                 }
 
                 // do the computation
@@ -1613,10 +1728,6 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
     cudaDataType_t      cu_data_type    = CUDA_R_16F;
 
-    if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
-        cu_compute_type = CUBLAS_COMPUTE_32F;
-    }
-
     // dst strides
     size_t nbd2 = dst->nb[2];
     size_t nbd3 = dst->nb[3];
@@ -1645,6 +1756,12 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         beta  = &beta_f32;
     }
 
+    if (GGML_CUDA_CC_IS_CDNA(ggml_cuda_info().devices[ctx.device].cc)) {
+        cu_compute_type = CUBLAS_COMPUTE_32F;
+        alpha = &alpha_f32;
+        beta  = &beta_f32;
+    }
+
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
 
@@ -1672,9 +1789,6 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         }
     }
 #else
-#ifdef GGML_USE_MUSA
-    GGML_ASSERT(false);
-#else // !GGML_USE_MUSA
     if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         // use cublasGemmStridedBatchedEx
@@ -1717,7 +1831,6 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     }
-#endif // GGML_USE_MUSA
 #endif
 
     if (dst->op_params[0] == GGML_PREC_DEFAULT) {
@@ -1752,14 +1865,14 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
 
             const int cc              = ggml_cuda_info().devices[id].cc;
             use_mul_mat_q             = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-            any_gpus_with_slow_fp16   = any_gpus_with_slow_fp16   || !fast_fp16_available(cc);
-            any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc);
+            any_gpus_with_slow_fp16   = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);
+            any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
         }
     } else {
         const int cc              = ggml_cuda_info().devices[ctx.device].cc;
         use_mul_mat_q             = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-        any_gpus_with_slow_fp16   = any_gpus_with_slow_fp16   || !fast_fp16_available(cc);
-        any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc);
+        any_gpus_with_slow_fp16   = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);
+        any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
     }
 
     // debug helpers
@@ -1770,7 +1883,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
-    if (!split && use_mul_mat_vec && dst->ne[3] == 1 && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
+    if (!split && use_mul_mat_vec && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
         // the custom F16 vector kernel can be used over batched cuBLAS GEMM
         // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
         ggml_cuda_mul_mat_vec(ctx, src0, src1, dst);
@@ -2003,6 +2116,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_GET_ROWS:
             ggml_cuda_op_get_rows(ctx, dst);
             break;
+        case GGML_OP_GET_ROWS_BACK:
+            ggml_cuda_op_get_rows_back(ctx, dst);
+            break;
         case GGML_OP_DUP:
             ggml_cuda_dup(ctx, dst);
             break;
@@ -2094,16 +2210,17 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_LEAKY_RELU:
             ggml_cuda_op_leaky_relu(ctx, dst);
             break;
+        case GGML_OP_SILU_BACK:
+            ggml_cuda_op_silu_back(ctx, dst);
+            break;
         case GGML_OP_RMS_NORM:
             ggml_cuda_op_rms_norm(ctx, dst);
             break;
+        case GGML_OP_RMS_NORM_BACK:
+            ggml_cuda_op_rms_norm_back(ctx, dst);
+            break;
         case GGML_OP_MUL_MAT:
-            if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
-                GGML_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
-                return false;
-            } else {
-                ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst);
-            }
+            ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst);
             break;
         case GGML_OP_MUL_MAT_ID:
             ggml_cuda_mul_mat_id(ctx, dst);
@@ -2141,9 +2258,15 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SOFT_MAX:
             ggml_cuda_op_soft_max(ctx, dst);
             break;
+        case GGML_OP_SOFT_MAX_BACK:
+            ggml_cuda_op_soft_max_back(ctx, dst);
+            break;
         case GGML_OP_ROPE:
             ggml_cuda_op_rope(ctx, dst);
             break;
+        case GGML_OP_ROPE_BACK:
+            ggml_cuda_op_rope_back(ctx, dst);
+            break;
         case GGML_OP_IM2COL:
             ggml_cuda_op_im2col(ctx, dst);
             break;
@@ -2173,6 +2296,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_RWKV_WKV6:
             ggml_cuda_op_rwkv_wkv6(ctx, dst);
             break;
+        case GGML_OP_GATED_LINEAR_ATTN:
+            ggml_cuda_op_gated_linear_attn(ctx, dst);
+            break;
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
             ggml_cuda_cross_entropy_loss_back(ctx, dst);
             break;
@@ -2291,6 +2417,66 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
 }
 
 #ifdef USE_CUDA_GRAPH
+static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
+    std::vector & ggml_cuda_cpy_fn_ptrs, bool use_cuda_graph) {
+
+    // Loop over nodes in GGML graph to obtain info needed for CUDA graph
+    cuda_ctx->cuda_graph->updated_kernel_arg.clear();
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        ggml_tensor * node = cgraph->nodes[i];
+
+        if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
+            continue;
+        }
+
+        if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) {
+            use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
+#ifndef NDEBUG
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__);
+#endif
+        }
+
+        if (node->op == GGML_OP_MUL_MAT_ID) {
+            use_cuda_graph = false; // This node type is not supported by CUDA graph capture
+#ifndef NDEBUG
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
+#endif
+        }
+
+        if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
+            // disable CUDA graphs for batch size > 1 for now.
+            // Changes in batch size or context size can cause changes to the grid size of some kernels.
+            use_cuda_graph = false;
+#ifndef NDEBUG
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
+#endif
+        }
+
+        if (node->op == GGML_OP_CPY) {
+            // store the copy op parameter which changes with each token.
+            cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
+            // store a pointer to each copy op CUDA kernel to identify it later
+            void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
+            if (!ptr) {
+                use_cuda_graph = false;
+#ifndef NDEBUG
+                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
+#endif
+            } else {
+                if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
+                    ggml_cuda_cpy_fn_ptrs.push_back(ptr);
+                }
+            }
+        }
+
+        if (!use_cuda_graph) {
+            break;
+        }
+    }
+
+    return use_cuda_graph;
+}
+
 static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
     graph_node_properties->node_address = node->data;
     graph_node_properties->node_op = node->op;
@@ -2341,149 +2527,111 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
 
     return true;
 }
-#endif
 
-static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
-    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vector & ggml_cuda_cpy_fn_ptrs, bool cuda_graph_update_required) {
 
-    ggml_cuda_set_device(cuda_ctx->device);
+    if (cuda_graph_update_required) {
+        // Extract nodes from graph
+        // First call with null argument gets number of nodes in graph
+        CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
+        // Subsequent call with non-null argument gets nodes
+        cuda_ctx->cuda_graph->nodes.clear();
+        cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
+        cuda_ctx->cuda_graph->params.clear();
+        cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
+        if (cuda_ctx->cuda_graph->num_nodes > 0) {
+            CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
 
-#ifdef USE_CUDA_GRAPH
-    static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
-
-    // Objects required for CUDA Graph
-    if (cuda_ctx->cuda_graph == nullptr) {
-        cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
-    }
-
-    bool use_cuda_graph = true;
-    bool cuda_graph_update_required = false;
-    // vector of pointers to CUDA cpy kernels, which are required to identify
-    // kernel parameters which need updated in the graph for each token
-    std::vector ggml_cuda_cpy_fn_ptrs;
-
-    if (cuda_ctx->cuda_graph->graph == nullptr) {
-        if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
-            cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
-#ifndef NDEBUG
-            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
-#endif
-        }
-    }
-
-    // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
-    // or previous graph capture failure.
-    // Also disable for multi-gpu for now. TO DO investigate
-    if (disable_cuda_graphs_due_to_env
-        || cuda_ctx->cuda_graph->disable_due_to_gpu_arch
-        || cuda_ctx->cuda_graph->disable_due_to_too_many_updates
-        || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
-        use_cuda_graph = false;
-    }
-
-    if (use_cuda_graph) {
-        if (cuda_ctx->cuda_graph->instance == nullptr) {
-            cuda_graph_update_required = true;
-        }
-
-        // Check if the graph size has changed
-        if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
-            cuda_graph_update_required = true;
-            cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
-        }
-
-        // Loop over nodes in GGML graph to determine if CUDA graph update is required
-        // and store properties to allow this comparison for the next token
-        for (int i = 0; i < cgraph->n_nodes; i++) {
-            bool has_matching_properties = true;
-            if (!cuda_graph_update_required) {
-                has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
-            }
-            if (!has_matching_properties) {
-                cuda_graph_update_required = true;
-            }
-            set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
-        }
-
-        // Loop over nodes in GGML graph to obtain info needed for CUDA graph
-        cuda_ctx->cuda_graph->updated_kernel_arg.clear();
-        for (int i = 0; i < cgraph->n_nodes; i++) {
-            ggml_tensor * node = cgraph->nodes[i];
-
-            if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
-                continue;
-            }
-
-            if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) {
-                use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__);
-#endif
-            }
-
-            if (node->op == GGML_OP_MUL_MAT_ID) {
-                use_cuda_graph = false; // This node type is not supported by CUDA graph capture
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
-#endif
-            }
-
-            if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
-                // disable CUDA graphs for batch size > 1 for now.
-                // Changes in batch size or context size can cause changes to the grid size of some kernels.
-                use_cuda_graph = false;
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
-#endif
-            }
-
-            if (node->op == GGML_OP_CPY) {
-                // store the copy op parameter which changes with each token.
-                cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
-                // store a pointer to each copy op CUDA kernel to identify it later
-                void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
-                if (!ptr) {
-                    use_cuda_graph = false;
-#ifndef NDEBUG
-                    GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
-#endif
-                } else {
-                    if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
-                        ggml_cuda_cpy_fn_ptrs.push_back(ptr);
+            // Loop over nodes, and extract kernel parameters from each node
+            for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
+                cudaGraphNodeType node_type;
+                CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
+                if (node_type == cudaGraphNodeTypeKernel) {
+                    cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
+                    if (stat == cudaErrorInvalidDeviceFunction) {
+                        // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
+                        // We don't need to update blas nodes, so clear error and move on.
+                        (void)cudaGetLastError();
+                    } else {
+                        GGML_ASSERT(stat == cudaSuccess);
                     }
                 }
             }
-
-            if (!use_cuda_graph) {
-                break;
+        }
+    } else {
+        // One of the arguments to the copy kernel is updated for each token, hence we need to
+        // replace that argument with the updated value in the CUDA graph
+        // on update steps, the live parameters will already be captured
+        int k = 0;
+        for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
+            if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
+                char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
+                cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
+                CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
             }
         }
-
-        // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
-        if (use_cuda_graph && cuda_graph_update_required) {
-            cuda_ctx->cuda_graph->number_consecutive_updates++;
-        } else {
-            cuda_ctx->cuda_graph->number_consecutive_updates = 0;
-        }
-
-        if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
-            cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
-#ifndef NDEBUG
-            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
-#endif
-        }
     }
+}
 
-    if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
-        CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
-    }
+static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
 
-#else
-    bool use_cuda_graph = false;
     bool cuda_graph_update_required = false;
-#endif // USE_CUDA_GRAPH
 
-    bool graph_evaluated_or_captured = false;
+    if (cuda_ctx->cuda_graph->instance == nullptr) {
+        cuda_graph_update_required = true;
+    }
+
+    // Check if the graph size has changed
+    if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
+        cuda_graph_update_required = true;
+        cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
+    }
+
+    // Loop over nodes in GGML graph to determine if CUDA graph update is required
+    // and store properties to allow this comparison for the next token
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        bool has_matching_properties = true;
+        if (!cuda_graph_update_required) {
+            has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
+        }
+        if (!has_matching_properties) {
+            cuda_graph_update_required = true;
+        }
+        set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
+    }
+
+    return cuda_graph_update_required;
+}
+
+static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
+
+    cudaGraphExecUpdateResultInfo result_info;
+#ifdef __HIP_PLATFORM_AMD__
+    hipGraphNode_t errorNode;
+    hipError_t stat = hipGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
+#else
+    cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
+#endif
+    if (stat == cudaErrorGraphExecUpdateFailure) {
+#ifndef NDEBUG
+        GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
+#endif
+
+        // The pre-existing graph exec cannot be updated due to violated constraints
+        // so instead clear error and re-instantiate
+        (void)cudaGetLastError();
+        CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
+        cuda_ctx->cuda_graph->instance = nullptr;
+        CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
+    } else {
+        GGML_ASSERT(stat == cudaSuccess);
+    }
+}
+#endif
+
+static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
+   [[maybe_unused]] std::vector & ggml_cuda_cpy_fn_ptrs,  bool & graph_evaluated_or_captured, bool & use_cuda_graph,
+    bool & cuda_graph_update_required) {
 
     while (!graph_evaluated_or_captured) {
         // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
@@ -2521,19 +2669,8 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
                 CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
                 cuda_ctx->cuda_graph->graph = nullptr;
             }
-            CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
 
-#if 0
-            if (disable_cuda_graphs_due_to_failed_capture) {
-                use_cuda_graph = false;
-                cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true;
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to failed graph capture\n", __func__);
-#endif
-            } else {
-                graph_evaluated_or_captured = true; // CUDA graph has been captured
-            }
-#endif
+            CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
             graph_evaluated_or_captured = true; // CUDA graph has been captured
         } else {
             graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
@@ -2546,72 +2683,91 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
         }
 
         // Perform update to graph (if required for this token), and change copy parameter (required for every token)
-
-        if (cuda_graph_update_required) {
-            // Extract nodes from graph
-            // First call with null argument gets number of nodes in graph
-            CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
-            // Subsequent call with non-null argument gets nodes
-            cuda_ctx->cuda_graph->nodes.clear();
-            cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
-            cuda_ctx->cuda_graph->params.clear();
-            cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
-            if (cuda_ctx->cuda_graph->num_nodes > 0) {
-                CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
-
-                // Loop over nodes, and extract kernel parameters from each node
-                for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
-                    cudaGraphNodeType node_type;
-                    CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
-                    if (node_type == cudaGraphNodeTypeKernel) {
-                        cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
-                        if (stat == cudaErrorInvalidDeviceFunction) {
-                            // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
-                            // We don't need to update blas nodes, so clear error and move on.
-                            cudaGetLastError();
-                        } else {
-                            GGML_ASSERT(stat == cudaSuccess);
-                        }
-                    }
-                }
-            }
-        }
-
-        // One of the arguments to the copy kernel is updated for each token, hence we need to
-        // replace that argument with the updated value in the CUDA graph
-        if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
-            int k = 0;
-            for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
-                if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
-                    char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
-                    cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
-                    CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
-                }
-            }
-        }
+        maintain_cuda_graph(cuda_ctx, ggml_cuda_cpy_fn_ptrs, cuda_graph_update_required);
 
         // Update graph executable
-        cudaGraphExecUpdateResultInfo result_info;
-        cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
-        if (stat == cudaErrorGraphExecUpdateFailure) {
-#ifndef NDEBUG
-            GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
-#endif
-            // The pre-existing graph exec cannot be updated due to violated constraints
-            // so instead clear error and re-instantiate
-            cudaGetLastError();
-            CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
-            cuda_ctx->cuda_graph->instance = nullptr;
-            CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
-        } else {
-            GGML_ASSERT(stat == cudaSuccess);
-        }
+        update_cuda_graph_executable(cuda_ctx);
+
         // Launch graph
         CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
 #else
         graph_evaluated_or_captured = true;
-#endif // USE_CUDA_GRAPH
+#endif  // USE_CUDA_GRAPH
     }
+}
+
+static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+
+    ggml_cuda_set_device(cuda_ctx->device);
+
+    // vector of pointers to CUDA cpy kernels, which are required to identify
+    // kernel parameters which need updated in the graph for each token
+    std::vector ggml_cuda_cpy_fn_ptrs;
+
+#ifdef USE_CUDA_GRAPH
+    static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
+
+    // Objects required for CUDA Graph
+    if (cuda_ctx->cuda_graph == nullptr) {
+        cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
+    }
+
+    bool use_cuda_graph = true;
+    bool cuda_graph_update_required = false;
+
+    if (cuda_ctx->cuda_graph->graph == nullptr) {
+        if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
+            cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
+#ifndef NDEBUG
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
+#endif
+        }
+    }
+
+    // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
+    // or previous graph capture failure.
+    // Also disable for multi-gpu for now. TO DO investigate
+    if (disable_cuda_graphs_due_to_env
+        || cuda_ctx->cuda_graph->disable_due_to_gpu_arch
+        || cuda_ctx->cuda_graph->disable_due_to_too_many_updates
+        || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
+        use_cuda_graph = false;
+    }
+
+    if (use_cuda_graph) {
+        cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
+
+        use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph,
+                             ggml_cuda_cpy_fn_ptrs, use_cuda_graph);
+
+        // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
+        if (use_cuda_graph && cuda_graph_update_required) {
+            cuda_ctx->cuda_graph->number_consecutive_updates++;
+        } else {
+            cuda_ctx->cuda_graph->number_consecutive_updates = 0;
+        }
+
+        if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
+            cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
+#ifndef NDEBUG
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
+#endif
+        }
+    }
+
+    if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
+        CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
+    }
+
+#else
+    bool use_cuda_graph = false;
+    bool cuda_graph_update_required = false;
+#endif // USE_CUDA_GRAPH
+
+    bool graph_evaluated_or_captured = false;
+
+    evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, ggml_cuda_cpy_fn_ptrs, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
 
     return GGML_STATUS_SUCCESS;
 }
@@ -2687,11 +2843,11 @@ bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
         return false;
     }
 
-#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
+#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
     cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
     if (err != cudaSuccess) {
         // clear the error
-        cudaGetLastError();
+        (void)cudaGetLastError();
 
         GGML_LOG_DEBUG("%s: failed to register %.2f MiB of pinned memory: %s\n", __func__,
                            size / 1024.0 / 1024.0, cudaGetErrorString(err));
@@ -2699,8 +2855,10 @@ bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
     }
     return true;
 #else
+    GGML_UNUSED(buffer);
+    GGML_UNUSED(size);
     return false;
-#endif
+#endif // CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
 }
 
 void ggml_backend_cuda_unregister_host_buffer(void * buffer) {
@@ -2711,7 +2869,7 @@ void ggml_backend_cuda_unregister_host_buffer(void * buffer) {
     cudaError_t err = cudaHostUnregister(buffer);
     if (err != cudaSuccess) {
         // clear the error
-        cudaGetLastError();
+        (void)cudaGetLastError();
     }
 }
 
@@ -2843,9 +3001,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
                     return false;
                 }
-                if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {
-                    return false;
-                }
 #ifdef GGML_USE_MUSA
                 if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
                     !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
@@ -2887,7 +3042,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 }
             } break;
         case GGML_OP_OUT_PROD:
-            return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
+            return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
         case GGML_OP_GET_ROWS:
             {
                 switch (op->src[0]->type) {
@@ -2903,6 +3058,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                         return false;
                 }
             } break;
+        case GGML_OP_GET_ROWS_BACK:
+            {
+                return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
+            } break;
         case GGML_OP_CPY:
             {
                 ggml_type src0_type = op->src[0]->type;
@@ -2922,15 +3081,27 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
                     return true;
                 }
+                if (src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
                     return true;
                 }
+                if (src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {
                     return true;
                 }
+                if (src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
                     return true;
                 }
+                if (src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
                     return true;
                 }
@@ -2961,7 +3132,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
             } break;
         case GGML_OP_REPEAT_BACK:
-                return op->type == GGML_TYPE_F32 && op->src[0]->ne[3] == 1;
+                return op->type == GGML_TYPE_F32 && (op->src[0]->ne[2]*op->src[0]->ne[3]) <= (1 << 15);
         case GGML_OP_CONCAT:
             {
                 ggml_type src0_type = op->src[0]->type;
@@ -2976,8 +3147,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 }
                 return false;
             } break;
+        case GGML_OP_SILU_BACK:
+            return ggml_is_contiguous(op->src[0]);
+            break;
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
+            return true;
+        case GGML_OP_RMS_NORM_BACK:
             return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
             break;
         case GGML_OP_NONE:
@@ -3002,15 +3178,26 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
             return true;
+        case GGML_OP_SOFT_MAX_BACK: {
+            float max_bias = 0.0f;
+            memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
+            return max_bias == 0.0f;
+        }
         case GGML_OP_ROPE:
-            return ggml_is_contiguous(op->src[0]);
+        case GGML_OP_ROPE_BACK: {
+            const size_t ts = ggml_type_size(op->src[0]->type);
+            const int64_t ne0_012 = op->src[0]->ne[0] * op->src[0]->ne[1] * op->src[0]->ne[2];
+            return op->src[0]->nb[0] == ts && op->src[0]->nb[3] == ne0_012*ts;
+        }
         case GGML_OP_IM2COL:
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGSORT:
         case GGML_OP_ACC:
+            return true;
         case GGML_OP_GROUP_NORM:
+            return ggml_is_contiguous(op->src[0]);
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
         case GGML_OP_UNPAD:
@@ -3018,11 +3205,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_RWKV_WKV6:
+        case GGML_OP_GATED_LINEAR_ATTN:
             return true;
         case GGML_OP_FLASH_ATTN_EXT: {
 #ifndef FLASH_ATTN_AVAILABLE
             return false;
-#endif
+#endif // FLASH_ATTN_AVAILABLE
             if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
                 return false;
             }
@@ -3035,8 +3223,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
                 return true;
             }
-            const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
-            return cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
+            return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
+                op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
         }
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
@@ -3059,6 +3247,7 @@ static int64_t get_op_batch_size(const ggml_tensor * op) {
             return op->ne[1];
         case GGML_OP_MUL_MAT_ID:
         case GGML_OP_ROPE:
+        case GGML_OP_ROPE_BACK:
             return op->ne[2];
         default:
             return ggml_nrows(op);
@@ -3161,7 +3350,7 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t
         features.push_back({ "FORCE_CUBLAS", "1" });
     #endif
 
-    #ifdef GGML_CUDA_NO_VMM
+    #ifndef GGML_USE_VMM
         features.push_back({ "NO_VMM", "1" });
     #endif
 
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/gla.cu b/ml/backend/ggml/ggml/src/ggml-cuda/gla.cu
new file mode 100644
index 00000000..f7d615a8
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/gla.cu
@@ -0,0 +1,93 @@
+#include "common.cuh"
+#include "gla.cuh"
+
+template
+static __global__ void gated_linear_attn_f32(const int B, const int T, const int C, const int H, const float scale,
+     const float * k, const float * v, const float * r, const float * td, const float * s, float * dst) {
+    const int tid = threadIdx.x;
+    const int bid = blockIdx.x;
+
+    const int head_size = HEAD_SIZE;
+    const int batch_i = bid / H;
+    const int head_i = bid % H;
+    const int state_size = C * head_size;
+    const int n_seq_tokens = T / B;
+
+    float state[head_size];
+    __shared__ float _k[head_size], _r[head_size], _td[head_size];
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
+    }
+
+    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
+        __syncthreads();
+        _k[tid] = k[t];
+        _r[tid] = r[t];
+        _td[tid] = td[t];
+        __syncthreads();
+
+        const float _v = v[t];
+        float y = 0;
+        for (int j = 0; j < head_size; j += 4) {
+            const float4 & k = (float4 &)(_k[j]);
+            const float4 & r = (float4 &)(_r[j]);
+            const float4 & td = (float4 &)(_td[j]);
+            float4 & s = (float4 &)(state[j]);
+            float4 kv;
+
+            kv.x = k.x * _v;
+            kv.y = k.y * _v;
+            kv.z = k.z * _v;
+            kv.w = k.w * _v;
+
+            s.x = s.x * td.x + kv.x;
+            s.y = s.y * td.y + kv.y;
+            s.z = s.z * td.z + kv.z;
+            s.w = s.w * td.w + kv.w;
+
+            y += r.x * s.x;
+            y += r.y * s.y;
+            y += r.z * s.z;
+            y += r.w * s.w;
+        }
+        dst[t] = y * scale;
+    }
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
+    }
+}
+
+void ggml_cuda_op_gated_linear_attn(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const float * k_d  = (const float *)dst->src[0]->data;
+    const float * v_d  = (const float *)dst->src[1]->data;
+    const float * r_d  = (const float *)dst->src[2]->data;
+    const float * td_d = (const float *)dst->src[3]->data;
+    const float * s_d  = (const float *)dst->src[4]->data;
+
+    const int64_t B = dst->src[4]->ne[1];
+    const int64_t T = dst->src[0]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t H = dst->src[0]->ne[1];
+
+    float scale;
+    memcpy(&scale, (float*)dst->op_params, sizeof(float));
+
+    float * dst_d = (float *)dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(dst->src[4]->type == GGML_TYPE_F32);
+    GGML_ASSERT(C % H == 0);
+    GGML_ASSERT(C / H == 64 || C / H == 128);
+
+
+    if (C / H == 64) {
+        gated_linear_attn_f32<64><<>>(B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);
+    } else {
+        gated_linear_attn_f32<128><<>>(B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);
+    }
+}
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/gla.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/gla.cuh
new file mode 100644
index 00000000..2c82ad7d
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/gla.cuh
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_gated_linear_attn(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mma.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/mma.cuh
index 7d11540a..9206bfeb 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/mma.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/mma.cuh
@@ -1,221 +1,394 @@
+// This file contains primitives that expose the tensor core PTX instructions for CUDA code.
+// The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout.
+// The documentation for the PTX instructions can be found under:
+//   https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction
+//
+// Like with nvcuda::wmma there are three types of matrix tiles: A, B, and C with A @ B = C.
+// A is a row-major matrix with shape M x K.
+// B is a column-major matrix with shape K x N.
+// C is a column-major matrix with shape M x N.
+// A, B, and C are represented using the same fundamental data type: a row-major matrix with I rows and J columns.
+// Note that J is measured in physical 32 bit elements instead of logical elements.
+// The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile.
+// All matrix tiles have ne physical 32 bit elements per warp.
+//
+// As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
+
 #include "common.cuh"
 
-struct mma_int_A_I16K4 {
-    static constexpr int I  = 16;
-    static constexpr int K  = 4;
-    static constexpr int ne = 2;
 
-    int x[ne] = {0};
+#if CUDART_VERSION >= 11080
 
-    static __device__ __forceinline__ int get_i(const int l) {
-        const int ret = (l%2) * (I/2) + threadIdx.x / K;
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  I);
+static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
+    int ret = 0;
+
+#ifdef NEW_MMA_AVAILABLE
+    asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
+        : "=r"(ret) : "r"(x));
+#else
+    NO_DEVICE_CODE;
+#endif // defined(NEW_MMA_AVAILABLE)
+    return ret;
+}
+
+#else
+
+static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
+    // Imagine transposing row-major matrix to column-major matrix.
+    const int src_i_low  = 2 * (threadIdx.x % 4);
+    const int src_i_high = src_i_low + 1;
+    const int src_j      = threadIdx.x / 4;
+
+    const int src_laneid_low  = src_i_low  * 4 + src_j / 2;
+    const int src_laneid_high = src_i_high * 4 + src_j / 2;
+
+    const int shift_low  = ((src_j + 0) % 2) * 16;
+    const int shift_high = ((src_j + 1) % 2) * 16;
+
+    const int ret_low  = (__shfl_sync(0xFFFFFFFF, x, src_laneid_low,  WARP_SIZE) >> shift_low)  & 0x0000FFFF;
+    const int ret_high = (__shfl_sync(0xFFFFFFFF, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000;
+
+    return ret_low | ret_high;
+}
+
+#endif // CUDART_VERSION >= 11080
+
+static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) {
+    half2 ret;
+    *((int *) &ret) = ggml_cuda_movmatrix(*((const int *) &x));
+    return ret;
+}
+
+namespace ggml_cuda_mma {
+
+    template 
+    struct tile {
+        static constexpr int I  = I_;
+        static constexpr int J  = J_;
+        static constexpr int ne = I * J / WARP_SIZE;
+        T x[ne] = {0};
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 8 && (J == 4 || J == 8)) {
+                return threadIdx.x / 4;
+            } else if constexpr (I == 16 && J == 8) {
+                return (l / 2) * 8 + threadIdx.x / 4;
+            } else if constexpr (I == 16 && J == 16) {
+                return ((l / 2) % 2) * 8 + threadIdx.x / 4;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 8 && J == 4) {
+                return threadIdx.x % 4;
+            } else if constexpr (I == 8 && J == 8) {
+                return 4 * l + threadIdx.x % 4;
+            } else if constexpr (I == 16 && J == 8) {
+                return 2 * (threadIdx.x % 4) + l % 2;
+            } else if constexpr (I == 16 && J == 16) {
+                return 8 * (l / 4) + 2 * (threadIdx.x % 4) + l % 2;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+    };
+
+    template 
+    struct tile {
+        static constexpr int I  = I_;
+        static constexpr int J  = J_;
+        static constexpr int ne = I * J / WARP_SIZE;
+        half2 x[ne] = {{0.0f, 0.0f}};
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (I == 8 && J == 8) {
+                return threadIdx.x / 4;
+            } else if constexpr (I == 16 && J == 4) {
+                return l * 8 + threadIdx.x / 4;
+            } else if constexpr (I == 16 && J == 8) {
+                return (l % 2) * 8 + threadIdx.x / 4;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (I == 8 && J == 8) {
+                return l * 4 + threadIdx.x % 4;
+            } else if constexpr (I == 16 && J == 4) {
+                return threadIdx.x % 4;
+            } else if constexpr (I == 16 && J == 8) {
+                return (l / 2) * 4 + threadIdx.x % 4;
+            } else {
+                static_assert(I == -1 && J == -1, "template specialization not implemented");
+            }
+        }
+    };
+
+    template 
+    static __device__ __forceinline__ tile get_half2(const tile & tile_float) {
+        tile ret;
+#pragma unroll
+        for (int l0 = 0; l0 < tile_float.ne; l0 += 2) {
+            ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
+        }
         return ret;
     }
 
-    static __device__ __forceinline__ int get_k(const int /* l */) {
-        const int ret = threadIdx.x % K;
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  K);
+    static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) {
+        tile<8, 8, half2> ret;
+        ret.x[0] = ggml_cuda_movmatrix(t.x[0]);
+        ret.x[1] = ggml_cuda_movmatrix(t.x[1]);
+
         return ret;
     }
 
-    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
-#if defined(INT8_MMA_AVAILABLE)
-        const int * xs = xs0 + (threadIdx.x%I)*stride;
-        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
-            : "+r"(x[0]), "+r"(x[1])
+    template 
+    static __device__ __forceinline__ void load_generic(tile & t, const T * __restrict__ xs0, const int stride) {
+#pragma unroll
+        for (int l = 0; l < t.ne; ++l) {
+            t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
+        }
+    }
+
+    template 
+    static __device__ __forceinline__ void load_ldmatrix(
+            tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) {
+#ifdef NEW_MMA_AVAILABLE
+        int * xi = (int *) t.x;
+        const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J;
+        asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+            : "=r"(xi[0]), "=r"(xi[1])
             : "l"(xs));
 #else
-#pragma unroll
-        for (int l = 0; l < ne; ++l) {
-            x[l] = xs0[get_i(l)*stride + get_k(l)];
-        }
-#endif // defined(INT8_MMA_AVAILABLE)
-    }
-};
-
-struct mma_int_A_I16K8 {
-    static constexpr int I  = 16;
-    static constexpr int K  = 8;
-    static constexpr int ne = 4;
-
-    int x[ne] = {0};
-
-    static __device__ __forceinline__ int get_i(const int l) {
-        const int ret = (l%2) * (I/2) + threadIdx.x / (K/2);
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  I);
-        return ret;
+        load_generic(t, xs0, stride);
+#endif // NEW_MMA_AVAILABLE
     }
 
-    static __device__ __forceinline__ int get_k(const int l) {
-        const int ret = (l/2) * (K/2) + threadIdx.x % (K/2);
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  K);
-        return ret;
-    }
-
-    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
-#if defined(INT8_MMA_AVAILABLE)
-        const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
-        asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
-            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
+    template 
+    static __device__ __forceinline__ void load_ldmatrix(
+            tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) {
+#ifdef NEW_MMA_AVAILABLE
+        int * xi = (int *) t.x;
+        const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride;
+        asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+            : "=r"(xi[0]), "=r"(xi[1])
             : "l"(xs));
 #else
-#pragma unroll
-        for (int l = 0; l < ne; ++l) {
-            x[l] = xs0[get_i(l)*stride + get_k(l)];
-        }
-#endif // defined(INT8_MMA_AVAILABLE)
+        load_generic(xs0, stride);
+#endif // NEW_MMA_AVAILABLE
     }
 
-    __device__ __forceinline__ void load_low(const int * __restrict__ xs0, const int & stride) {
-        ((mma_int_A_I16K4 *) x)[0].load(xs0, stride);
-    }
-};
-
-struct mma_int_B_J8K4 {
-    static constexpr int J  = 8;
-    static constexpr int K  = 4;
-    static constexpr int ne = 1;
-
-    int x[ne] = {0};
-
-    static __device__ __forceinline__ int get_j(const int /* l */) {
-        const int ret = threadIdx.x / K;
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  J);
-        return ret;
-    }
-
-    static __device__ __forceinline__ int get_k(const int /* l */) {
-        const int ret = threadIdx.x % K;
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  K);
-        return ret;
-    }
-
-    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
-#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
-        const int * xs = xs0 + (threadIdx.x%J)*stride;
-        asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
-            : "+r"(x[0])
+    template 
+    static __device__ __forceinline__ void load_ldmatrix(
+            tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
+#ifdef NEW_MMA_AVAILABLE
+        int * xi = (int * ) t.x;
+        const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
+        asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
+            : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
             : "l"(xs));
 #else
-#pragma unroll
-        for (int l = 0; l < ne; ++l) {
-            x[l] = xs0[get_j(l)*stride + get_k(l)];
-        }
-#endif // defined(INT8_MMA_AVAILABLE)
-    }
-};
-
-struct mma_int_B_J8K8 {
-    static constexpr int J  = 8;
-    static constexpr int K  = 8;
-    static constexpr int ne = 2;
-
-    int x[ne] = {0};
-
-    static __device__ __forceinline__ int get_j(const int /* l */) {
-        const int ret = threadIdx.x / (K/2);
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  J);
-        return ret;
+        load_generic(t, xs0, stride);
+#endif // NEW_MMA_AVAILABLE
     }
 
-    static __device__ __forceinline__ int get_k(const int l) {
-        const int ret = l * (K/2) + threadIdx.x % (K/2);
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  K);
-        return ret;
-    }
-
-    __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
-#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
-        const int * xs = xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
-        asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
-            : "+r"(x[0]), "+r"(x[1])
+    template 
+    static __device__ __forceinline__ void load_ldmatrix_trans(
+            tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
+#ifdef NEW_MMA_AVAILABLE
+        int * xi = (int * ) t.x;
+        const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
+        asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
+            : "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3])
             : "l"(xs));
 #else
-#pragma unroll
-        for (int l = 0; l < ne; ++l) {
-            x[l] = xs0[get_j(l)*stride + get_k(l)];
-        }
-#endif // defined(INT8_MMA_AVAILABLE)
-    }
-};
-
-struct mma_int_C_I16J8 {
-    static constexpr int I  = 16;
-    static constexpr int J  = 8;
-    static constexpr int ne = 4;
-
-    int x[ne] = {0};
-
-    static __device__ __forceinline__ int get_i(const int l) {
-        const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  I);
-        return ret;
+        GGML_UNUSED(t);
+        GGML_UNUSED(xs0);
+        GGML_UNUSED(stride);
+        NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
     }
 
-    static __device__ __forceinline__ int get_j(const int l) {
-        const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
-        GGML_CUDA_ASSUME(ret >= 0);
-        GGML_CUDA_ASSUME(ret <  J);
-        return ret;
-    }
-
-    __device__ __forceinline__ void mma_K4(const mma_int_A_I16K4 & mma_A, const mma_int_B_J8K4 & mma_B) {
-#ifdef INT8_MMA_AVAILABLE
+    static __device__ __forceinline__ void mma(
+            tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) {
+#ifdef NEW_MMA_AVAILABLE
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
         asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
-            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
-            : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+            : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
+            : "r"(A.x[0]), "r"(A.x[1]), "r"(B.x[0]));
 #else
         // On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead:
         asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
-            : "+r"(x[0]), "+r"(x[1])
-            : "r"(mma_A.x[0]), "r"(mma_B.x[0]));
+            : "+r"(D.x[0]), "+r"(D.x[1])
+            : "r"(A.x[0]), "r"(B.x[0]));
         asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
-            : "+r"(x[2]), "+r"(x[3])
-            : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+            : "+r"(D.x[2]), "+r"(D.x[3])
+            : "r"(A.x[1]), "r"(B.x[0]));
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #else
-        GGML_UNUSED(mma_A);
-        GGML_UNUSED(mma_B);
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
         NO_DEVICE_CODE;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
-    __device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) {
-#ifdef INT8_MMA_AVAILABLE
+    static __device__ __forceinline__ void mma(
+            tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) {
+#ifdef NEW_MMA_AVAILABLE
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
         asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
-            : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
-            : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_A.x[2]), "r"(mma_A.x[3]), "r"(mma_B.x[0]), "r"(mma_B.x[1]));
+            : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
+            : "r"(A.x[0]), "r"(A.x[1]), "r"(A.x[2]), "r"(A.x[3]), "r"(B.x[0]), "r"(B.x[1]));
 #else
         // On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead:
         asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
-            : "+r"(x[0]), "+r"(x[1])
-            : "r"(mma_A.x[0]), "r"(mma_B.x[0]));
+            : "+r"(D.x[0]), "+r"(D.x[1])
+            : "r"(A.x[0]), "r"(B.x[0]));
         asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
-            : "+r"(x[2]), "+r"(x[3])
-            : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
+            : "+r"(D.x[2]), "+r"(D.x[3])
+            : "r"(A.x[1]), "r"(B.x[0]));
         asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
-            : "+r"(x[0]), "+r"(x[1])
-            : "r"(mma_A.x[2]), "r"(mma_B.x[1]));
+            : "+r"(D.x[0]), "+r"(D.x[1])
+            : "r"(A.x[2]), "r"(B.x[1]));
         asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
-            : "+r"(x[2]), "+r"(x[3])
-            : "r"(mma_A.x[3]), "r"(mma_B.x[1]));
+            : "+r"(D.x[2]), "+r"(D.x[3])
+            : "r"(A.x[3]), "r"(B.x[1]));
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #else
-        GGML_UNUSED(mma_A);
-        GGML_UNUSED(mma_B);
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
         NO_DEVICE_CODE;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
-};
+
+    static __device__ __forceinline__ void mma(
+            tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
+#ifdef NEW_MMA_AVAILABLE
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int       * Dxi = (int       *) D.x;
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+        asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
+#else
+        // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
+        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#else
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
+    }
+
+    static __device__ __forceinline__ void mma(
+            tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
+#ifdef NEW_MMA_AVAILABLE
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int       * Dxi = (int       *) D.x;
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+        asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
+        asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
+            : "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
+#else
+        // On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead:
+        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+            : "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+            : "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#else
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
+    }
+
+    static __device__ __forceinline__ void mma(
+            tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
+#ifdef NEW_MMA_AVAILABLE
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int       * Dxi = (int       *) D.x;
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+        asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
+#else
+        // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#else
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
+    }
+
+    static __device__ __forceinline__ void mma(
+            tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
+#ifdef NEW_MMA_AVAILABLE
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int       * Dxi = (int       *) D.x;
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+        asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
+        asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
+#else
+        // On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead:
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#else
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
+    }
+}
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu b/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu
index 270251df..10f2ebb1 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu
@@ -18,7 +18,7 @@ void ggml_cuda_op_mul_mat_q(
     const int64_t stride00 = ne00 / ggml_blck_size(src0->type);
 
     int id = ggml_cuda_get_device();
-    const int compute_capability = ggml_cuda_info().devices[id].cc;
+    const int cc = ggml_cuda_info().devices[id].cc;
 
     // the main device has a larger memory buffer to hold the results from all GPUs
     // nrows_dst == nrows of the matrix that the kernel writes into
@@ -27,7 +27,8 @@ void ggml_cuda_op_mul_mat_q(
     // The stream-k decomposition is only faster for recent NVIDIA GPUs.
     // Also its fixup needs to allocate a temporary buffer in the memory pool.
     // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
-    const bool use_stream_k = compute_capability >= GGML_CUDA_CC_VOLTA && compute_capability < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11;
+    const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA &&
+        cc < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11;
     const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst, use_stream_k};
 
     switch (src0->type) {
@@ -132,11 +133,11 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         return false;
     }
 
-    if (int8_mma_available(cc)) {
+    if (new_mma_available(cc)) {
         return true;
     }
 
-    if (cc < GGML_CUDA_CC_DP4A) {
+    if (ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_DP4A) {
         return false;
     }
 
@@ -145,8 +146,8 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
 #endif //GGML_CUDA_FORCE_MMQ
 
     if (cc < GGML_CUDA_CC_OFFSET_AMD) {
-        return cc < GGML_CUDA_CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
+        return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
     }
 
-    return (cc < GGML_CUDA_CC_RDNA3 && cc != GGML_CUDA_CC_CDNA && cc != GGML_CUDA_CC_VEGA20) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
+    return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
 }
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh
index 3cd508a1..0451c65f 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh
@@ -7,6 +7,8 @@
 #include 
 #include 
 
+using namespace ggml_cuda_mma;
+
 #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
 #define MMQ_ITER_K 256
 #define MMQ_NWARPS 8
@@ -86,19 +88,20 @@ struct tile_x_sizes {
     int sc;
 };
 
-static constexpr int get_mmq_x_max_host(const int cc) {
-    return int8_mma_available(cc) ? 128 :
+static int get_mmq_x_max_host(const int cc) {
+    return new_mma_available(cc) ? 128 :
+        ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ?
 #ifdef GGML_CUDA_FORCE_MMQ
-        cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? 128                     : 64;
+            128                     : 64;
 #else
-        cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? MMQ_DP4A_MAX_BATCH_SIZE : 64;
+            MMQ_DP4A_MAX_BATCH_SIZE : 64;
 #endif // GGML_CUDA_FORCE_MMQ
 }
 
 static constexpr __device__ int get_mmq_x_max_device() {
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     return 128;
-#else // INT8_MMA_AVAILABLE
+#else // NEW_MMA_AVAILABLE
 
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
     return 128;
@@ -116,11 +119,12 @@ static constexpr __device__ int get_mmq_x_max_device() {
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 }
 
-static constexpr int get_mmq_y_host(const int cc) {
-    return cc >= GGML_CUDA_CC_OFFSET_AMD ? (cc == GGML_CUDA_CC_RDNA1 ? 64 : 128) : (cc >= GGML_CUDA_CC_VOLTA ? 128 : 64);
+static int get_mmq_y_host(const int cc) {
+    return cc >= GGML_CUDA_CC_OFFSET_AMD ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) :
+        (ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ? 128 : 64);
 }
 
 static constexpr __device__ int get_mmq_y_device() {
@@ -209,10 +213,10 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
 #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
 
 static int mmq_get_granularity_host(const int mmq_x, const int cc) {
-    return int8_mma_available(cc) && mmq_x >= 48 ? 16 : 8;
+    return new_mma_available(cc) && mmq_x >= 48 ? 16 : 8;
 }
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
 static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
     return mmq_x >= 48 ? 16 : 8;
 }
@@ -220,21 +224,21 @@ static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
 static constexpr __device__ int mmq_get_granularity_device(const int /* mmq_x */) {
     return 8;
 }
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
 // ------------------------------------------------------------
 
 template  static __device__ __forceinline__ void load_tiles_q4_0(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + 2*WARP_SIZE);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kbx  = threadIdx.x / QI4_0;
     const int kqsx = threadIdx.x % QI4_0;
@@ -250,12 +254,12 @@ template  static __device__ __forceinlin
         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
         const int qs0 = get_int_b2(bxi->qs, kqsx);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0]     = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
 #else
         x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
@@ -271,11 +275,11 @@ template  static __device__ __forceinlin
 
         const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0       + kbxd] = bxi->d;
 #else
         x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
@@ -322,14 +326,14 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
 template  static __device__ __forceinline__ void load_tiles_q4_1(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kbx  = threadIdx.x / QI4_1;
     const int kqsx = threadIdx.x % QI4_1;
@@ -345,12 +349,12 @@ template  static __device__ __forceinlin
         const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
         const int qs0 = get_int_b4(bxi->qs, kqsx);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0]     = (qs0 >> 0) & 0x0F0F0F0F;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
 #else
         x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
@@ -366,11 +370,11 @@ template  static __device__ __forceinlin
 
         const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_dm[i*MMQ_MMA_TILE_X_K_Q8_1       + kbxd] = bxi->dm;
 #else
         x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
@@ -417,14 +421,14 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
 template  static __device__ __forceinline__ void load_tiles_q5_0(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kbx  = threadIdx.x / QI5_0;
     const int kqsx = threadIdx.x % QI5_0;
@@ -456,13 +460,13 @@ template  static __device__ __forceinlin
         qs1    |= (qh <<  9)   & 0x10000000;  // 19 -> 28
         qs1     = __vsubss4(qs1, 0x10101010); // subtract 16
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0]     = qs0;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
 #else
         x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI5_0) + kqsx + 0]     = qs0;
         x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
@@ -478,25 +482,25 @@ template  static __device__ __forceinlin
 
         const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0       + kbxd] = bxi->d;
 #else
         x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
 template  static __device__ __forceinline__ void load_tiles_q5_1(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kbx  = threadIdx.x / QI5_1;
     const int kqsx = threadIdx.x % QI5_1;
@@ -526,13 +530,13 @@ template  static __device__ __forceinlin
         qs1    |= (qh <<  2) & 0x00100000; // 18 -> 20
         qs1    |= (qh <<  9) & 0x10000000; // 19 -> 28
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0]     = qs0;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
 #else
         x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI5_1) + kqsx + 0]     = qs0;
         x_qs[i*(2*WARP_SIZE + 1)     + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
@@ -548,25 +552,25 @@ template  static __device__ __forceinlin
 
         const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_dm[i*MMQ_MMA_TILE_X_K_Q8_1       + kbxd] = bxi->dm;
 #else
         x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
 template  static __device__ __forceinline__ void load_tiles_q8_0(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_tile + 2*WARP_SIZE);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kbx  = threadIdx.x / QI8_0;
     const int kqsx = threadIdx.x % QI8_0;
@@ -581,13 +585,13 @@ template  static __device__ __forceinlin
 
         const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0         + threadIdx.x] = get_int_b2(bxi[0].qs,               kqsx);
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
 #else
         x_qs[i*(2*WARP_SIZE + 1)     + 0         + threadIdx.x] = get_int_b2(bxi[0].qs,               kqsx);
         x_qs[i*(2*WARP_SIZE + 1)     + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = 2*WARP_SIZE / QI8_0;
@@ -603,11 +607,11 @@ template  static __device__ __forceinlin
 
         const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0             + kbxd] = bxi->d;
 #else
         x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
@@ -645,15 +649,15 @@ template 
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 
-    typedef mma_int_A_I16K8 mma_A;
-    typedef mma_int_B_J8K8  mma_B;
-    typedef mma_int_C_I16J8 mma_C;
+    typedef tile<16, 8, int> tile_A;
+    typedef tile< 8, 8, int> tile_B;
+    typedef tile<16, 8, int> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
-    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
 
     const int   * x_qs = (const int   *) x;
     const float * x_df = (const float *) x_qs + 2*WARP_SIZE;
@@ -661,8 +665,8 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     const float * y_df = (const float *) y;
     const half2 * y_ds = (const half2 *) y;
 
-    mma_A A[ntx][WARP_SIZE/QI8_0];
-    float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0];
+    tile_A A[ntx][WARP_SIZE/QI8_0];
+    float dA[ntx][tile_C::ne/2][WARP_SIZE/QI8_0];
 
     const int i0 = (threadIdx.y/ntx)*rows_per_warp;
 
@@ -672,12 +676,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
             const int k0 = k00 + k01;
 
-            A[n][k01/QI8_0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
+            load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
         }
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
+        for (int l = 0; l < tile_C::ne/2; ++l) {
+            const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
 
 #pragma unroll
             for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
@@ -689,17 +693,17 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     }
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 #pragma unroll
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
-            mma_B  B;
-            float dB[mma_C::ne/2];
+            tile_B B;
+            float dB[tile_C::ne/2];
 
-            B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
 
 #pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
-                const int j = j0 + mma_C::get_j(l);
+            for (int l = 0; l < tile_C::ne/2; ++l) {
+                const int j = j0 + tile_C::get_j(l);
 
                 if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
                     dB[l] =             y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
@@ -710,12 +714,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
 
 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
-                mma_C C;
-                C.mma_K8(A[n][k01/QI8_0], B);
+                tile_C C;
+                mma(C, A[n][k01/QI8_0], B);
 
 #pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
                 }
             }
         }
@@ -756,23 +760,23 @@ template 
 static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
 
-    typedef mma_int_A_I16K8 mma_A;
-    typedef mma_int_B_J8K8  mma_B;
-    typedef mma_int_C_I16J8 mma_C;
+    typedef tile<16, 8, int> tile_A;
+    typedef tile< 8, 8, int> tile_B;
+    typedef tile<16, 8, int> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
-    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_B::J*MMQ_TILE_Y_K);
 
     const int   * x_qs = (const int   *) x;
     const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE;
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_dm = (const half2 *) y;
 
-    mma_A    A[ntx][WARP_SIZE/QI8_1];
-    float2 dmA[ntx][mma_C::ne/2][WARP_SIZE/QI8_1];
+    tile_A   A[ntx][WARP_SIZE/QI8_1];
+    float2 dmA[ntx][tile_C::ne/2][WARP_SIZE/QI8_1];
 
     const int i0 = (threadIdx.y/ntx)*rows_per_warp;
 
@@ -782,12 +786,12 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
             const int k0 = k00 + k01;
 
-            A[n][k01/QI8_1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
+            load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
         }
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + n*mma_A::I + mma_C::get_i(2*l);
+        for (int l = 0; l < tile_C::ne/2; ++l) {
+            const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
 
 #pragma unroll
             for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
@@ -799,30 +803,30 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
     }
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 #pragma unroll
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
-            mma_B    B;
-            float2 dsB[mma_C::ne/2];
+            tile_B   B;
+            float2 dsB[tile_C::ne/2];
 
-            B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+            load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
 
 #pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
-                const int j = j0 + mma_C::get_j(l);
+            for (int l = 0; l < tile_C::ne/2; ++l) {
+                const int j = j0 + tile_C::get_j(l);
 
                 dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
             }
 
 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
-                mma_C C;
-                C.mma_K8(A[n][k01/QI8_1], B);
+                tile_C C;
+                mma(C, A[n][k01/QI8_1], B);
 
 #pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
                 }
             }
         }
@@ -864,28 +868,28 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
 template 
 static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
 
-    typedef mma_int_A_I16K4 mma_A;
-    typedef mma_int_A_I16K8 mma_A_K8;
-    typedef mma_int_B_J8K4  mma_B;
-    typedef mma_int_C_I16J8 mma_C;
+    typedef tile<16, 4, int> tile_A;
+    typedef tile<16, 8, int> tile_A_8;
+    typedef tile< 8, 4, int> tile_B;
+    typedef tile<16, 8, int> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
-    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
 
     const int   * x_qs = (const int   *) x;
     const float * x_df = (const float *) x_qs + WARP_SIZE*2;
     const int   * y_qs = (const int   *) y + 4;
     const float * y_df = (const float *) y;
 
-    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+    const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
 
-    mma_A   A[ntx][8];
-    float  dA[ntx][mma_C::ne/2][8];
+    tile_A  A[ntx][8];
+    float  dA[ntx][tile_C::ne/2][8];
 
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
@@ -893,12 +897,12 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
         for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
             const int k0 = k00 + k01;
 
-            ((mma_A_K8 *) A[n])[k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
+            load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
         }
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+        for (int l = 0; l < tile_C::ne/2; ++l) {
+            const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
 
 #pragma unroll
             for (int k01 = 0; k01 < WARP_SIZE; k01 += 4) {
@@ -910,31 +914,32 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
     }
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 #pragma unroll
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
-            mma_B B[2];
-            float dB[mma_C::ne/2];
+            tile_B B[2];
+            float dB[tile_C::ne/2];
 
-            B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0),        MMQ_TILE_Y_K);
-            B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
+            // Here load_generic is faster than load_ldmatrix.
+            load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0),         MMQ_TILE_Y_K);
+            load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);
 
 #pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
-                const int j = j0 + mma_C::get_j(l);
+            for (int l = 0; l < tile_C::ne/2; ++l) {
+                const int j = j0 + tile_C::get_j(l);
 
                 dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
             }
 
 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
-                mma_C C[2];
-                C[0].mma_K4(A[n][k01/4 + 0], B[0]);
-                C[1].mma_K4(A[n][k01/4 + 1], B[1]);
+                tile_C C[2];
+                mma(C[0], A[n][k01/4 + 0], B[0]);
+                mma(C[1], A[n][k01/4 + 1], B[1]);
 
 #pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
                 }
             }
         }
@@ -942,20 +947,20 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
 #else
     GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
     NO_DEVICE_CODE;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 }
 
 template  static __device__ __forceinline__ void load_tiles_q2_K(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kqsx = threadIdx.x % QI2_K;
 
@@ -977,11 +982,11 @@ template  static __device__ __forceinlin
 
             const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
             x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
 #else
             x_qs[i*(2*WARP_SIZE + 1)     + k] = x_qs_k;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
         }
 
         const int sc_m = bxi->scales[kqsx];
@@ -992,11 +997,11 @@ template  static __device__ __forceinlin
         const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
 #endif // FAST_FP16_AVAILABLE
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
 #else
         x_dm[i*(WARP_SIZE + 1)       + kqsx] = x_dm_ik;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
@@ -1051,29 +1056,29 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
 template 
 static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
 
-    typedef mma_int_A_I16K4 mma_A;
-    typedef mma_int_A_I16K8 mma_A_K8;
-    typedef mma_int_B_J8K4  mma_B;
-    typedef mma_int_C_I16J8 mma_C;
+    typedef tile<16, 4, int> tile_A;
+    typedef tile<16, 8, int> tile_A_8;
+    typedef tile< 8, 4, int> tile_B;
+    typedef tile<16, 8, int> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
-    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
 
     const int   * x_qs = (const int   *) x;
     const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2;
     const int   * y_qs = (const int   *) y + 4;
     const half2 * y_ds = (const half2 *) y;
 
-    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+    const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
 
-    mma_A   A[ntx][8];
-    float  dA[ntx][mma_C::ne/2][8];
-    float  mA[ntx][mma_C::ne/2][8];
+    tile_A  A[ntx][8];
+    float  dA[ntx][tile_C::ne/2][8];
+    float  mA[ntx][tile_C::ne/2][8];
 
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
@@ -1081,15 +1086,15 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
             const int k0 = k00 + k01;
 
-            ((mma_A_K8 *) A[n])[k01/QI8_1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
+            load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
         }
     }
 
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+        for (int l = 0; l < tile_C::ne/2; ++l) {
+            const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
 
 #pragma unroll
             for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1/2) {
@@ -1104,57 +1109,58 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
     }
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
-        float2 dB[mma_C::ne/2];
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+        float2 dB[tile_C::ne/2];
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int j = j0 + mma_C::get_j(l);
+        for (int l = 0; l < tile_C::ne/2; ++l) {
+            const int j = j0 + tile_C::get_j(l);
 
             dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
         }
 
 #pragma unroll
         for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) {
-            mma_B B[2];
+            tile_B B[2];
 
-            B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0),        MMQ_TILE_Y_K);
-            B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K);
+            // Here load_generic is faster than load_ldmatrix.
+            load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0),         MMQ_TILE_Y_K);
+            load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);
 
-            mma_C Cm[2];
+            tile_C Cm[2];
             if (k01 >= WARP_SIZE * 3/4) {
-                mma_A A1;
+                tile_A A1;
                 A1.x[0] = 0x01010101;
                 A1.x[1] = 0x01010101;
-                Cm[0].mma_K4(A1, B[0]);
-                Cm[1].mma_K4(A1, B[1]);
+                mma(Cm[0], A1, B[0]);
+                mma(Cm[1], A1, B[1]);
             }
 
 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
-                mma_C Cd[2];
+                tile_C Cd[2];
 
-                Cd[0].mma_K4(A[n][k01/4 + 0], B[0]);
-                Cd[1].mma_K4(A[n][k01/4 + 1], B[1]);
+                mma(Cd[0], A[n][k01/4 + 0], B[0]);
+                mma(Cd[1], A[n][k01/4 + 1], B[1]);
 
 #pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
+                for (int l = 0; l < tile_C::ne; ++l) {
                     float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1];
                     if (k01 >= WARP_SIZE * 3/4) {
                         tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1];
                     }
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y);
                 }
             }
         }
 
 #pragma unroll
         for (int k01 = 0; k01 < WARP_SIZE * 3/4; k01 += QI8_1) {
-            float2 sB[mma_C::ne/2];
+            float2 sB[tile_C::ne/2];
 
 #pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
-                const int j = j0 + mma_C::get_j(l);
+            for (int l = 0; l < tile_C::ne/2; ++l) {
+                const int j = j0 + tile_C::get_j(l);
 
                 sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
             }
@@ -1162,9 +1168,9 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
-                    sum[(j0/mma_C::J + n)*mma_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
                 }
             }
         }
@@ -1172,13 +1178,13 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
 #else
     GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
     NO_DEVICE_CODE;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 }
 
 template  static __device__ __forceinline__ void load_tiles_q3_K(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
 #else
@@ -1186,7 +1192,7 @@ template  static __device__ __forceinlin
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
     int   * x_sc = (int   *) (x_df + txs.dm);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kqsx = threadIdx.x % QI3_K;
 
@@ -1212,11 +1218,11 @@ template  static __device__ __forceinlin
 
             const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
 #else
             x_qs[i*(2*WARP_SIZE + 1)     + k] = x_qs_k;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
         }
     }
 
@@ -1242,7 +1248,7 @@ template  static __device__ __forceinlin
 
         const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         const int8_t * sc8 = (const int8_t *) ≻
         const float d = bxi->d;
 
@@ -1252,10 +1258,10 @@ template  static __device__ __forceinlin
         }
 #else
         x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
-#ifndef INT8_MMA_AVAILABLE
+#ifndef NEW_MMA_AVAILABLE
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) {
         int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y;
@@ -1268,7 +1274,7 @@ template  static __device__ __forceinlin
 
         x_df[i] = bxi->d;
     }
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 }
 
 template 
@@ -1317,7 +1323,7 @@ static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, co
 template  static __device__ __forceinline__ void load_tiles_q4_K(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE);
 #else
@@ -1325,7 +1331,7 @@ template  static __device__ __forceinlin
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
     int   * x_sc = (int   *) (x_dm + txs.dm);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -1338,15 +1344,15 @@ template  static __device__ __forceinlin
         const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
         const int qs0 = get_int_b4(bxi->qs, threadIdx.x);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
 #else
         x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
@@ -1407,7 +1413,7 @@ template  static __device__ __forceinlin
 
         x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
     }
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 }
 
 template 
@@ -1446,7 +1452,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
 template  static __device__ __forceinline__ void load_tiles_q5_K(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2);
 #else
@@ -1454,7 +1460,7 @@ template  static __device__ __forceinlin
     int   * x_qs = (int   *)  x_tile;
     half2 * x_dm = (half2 *) (x_qs + txs.qs);
     int   * x_sc = (int   *) (x_dm + txs.dm);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -1478,16 +1484,16 @@ template  static __device__ __forceinlin
         const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0;
         const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + QI5_K/4;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
 #else
         x_qs[i*(2*WARP_SIZE + 1)     + kq0] = ql0 | qh0;
         x_qs[i*(2*WARP_SIZE + 1)     + kq1] = ql1 | qh1;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) {
@@ -1548,7 +1554,7 @@ template  static __device__ __forceinlin
 
         x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8;
     }
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 }
 
 template 
@@ -1587,7 +1593,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
 template  static __device__ __forceinline__ void load_tiles_q6_K(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
     int   * x_sc = (int   *) (x_df + WARP_SIZE/QI6_K);
@@ -1596,7 +1602,7 @@ template  static __device__ __forceinlin
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
     int   * x_sc = (int   *) (x_df + txs.dm);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -1619,13 +1625,13 @@ template  static __device__ __forceinlin
         const int kq0 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + 0;
         const int kq1 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + QI6_K/2;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
         x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
 #else
         x_qs[i*(2*WARP_SIZE + 1)     + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
         x_qs[i*(2*WARP_SIZE + 1)     + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI6_K;  // == 1 if QK_K == 256
@@ -1641,11 +1647,11 @@ template  static __device__ __forceinlin
 
         const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q6_K       + kbxd] = bxi->d;
 #else
         x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + kbxd] = bxi->d;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
 #pragma unroll
@@ -1658,11 +1664,11 @@ template  static __device__ __forceinlin
 
         const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
 #else
         x_sc[i*(WARP_SIZE/8) + i/8   + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8));
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
@@ -1702,17 +1708,17 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
 template 
 static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
 
-    typedef mma_int_A_I16K4 mma_A;
-    typedef mma_int_B_J8K4  mma_B;
-    typedef mma_int_C_I16J8 mma_C;
+    typedef tile<16, 4, int> tile_A;
+    typedef tile< 8, 4, int> tile_B;
+    typedef tile<16, 8, int> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
-    y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K);
+    y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K);
 
     const int   * x_qs = (const int   *) x;
     const float * x_df = (const float *) x_qs + WARP_SIZE*2;
@@ -1720,11 +1726,11 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
     const int   * y_qs = (const int   *) y + 4;
     const float * y_df = (const float *) y;
 
-    const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I);
+    const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
 
-    mma_A   A[ntx][8];
-    int   scA[ntx][mma_C::ne/2][8];
-    float  dA[ntx][mma_C::ne/2];
+    tile_A   A[ntx][8];
+    int    scA[ntx][tile_C::ne/2][8];
+    float   dA[ntx][tile_C::ne/2];
 
 #pragma unroll
     for (int n = 0; n < ntx; ++n) {
@@ -1732,8 +1738,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
         for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
             const int k0 = k00 + k01;
 
-            A[n][k01/4 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0),        MMQ_MMA_TILE_X_K_Q6_K);
-            A[n][k01/4 + 1].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + mma_A::K), MMQ_MMA_TILE_X_K_Q6_K);
+            load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0),         MMQ_MMA_TILE_X_K_Q6_K);
+            load_ldmatrix(A[n][k01/4 + 1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + tile_A::J), MMQ_MMA_TILE_X_K_Q6_K);
         }
 
 #pragma unroll
@@ -1741,8 +1747,8 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
             const int k0 = k00 + k01;
 
 #pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
-                const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+            for (int l = 0; l < tile_C::ne/2; ++l) {
+                const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
 
                 const int      sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16];
                 const int8_t * sc        = (const int8_t *) &sc_packed;
@@ -1755,40 +1761,41 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
         }
 
 #pragma unroll
-        for (int l = 0; l < mma_C::ne/2; ++l) {
-            const int i = i0 + n*mma_C::I + mma_C::get_i(2*l);
+        for (int l = 0; l < tile_C::ne/2; ++l) {
+            const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
 
             dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K];
         }
     }
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
-        float tmp[ntx][mma_C::ne] = {{0.0f}};
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+        float tmp[ntx][tile_C::ne] = {{0.0f}};
 
 #pragma unroll
         for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) {
-            mma_B B[2];
-            float dB[mma_C::ne/2];
+            tile_B B[2];
+            float dB[tile_C::ne/2];
 
-            B[0].load(y_qs + j0*MMQ_TILE_Y_K + 0        + k01, MMQ_TILE_Y_K);
-            B[1].load(y_qs + j0*MMQ_TILE_Y_K + mma_B::K + k01, MMQ_TILE_Y_K);
+            // Here load_generic is faster than load_ldmatrix.
+            load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + 0         + k01, MMQ_TILE_Y_K);
+            load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + tile_B::J + k01, MMQ_TILE_Y_K);
 
 #pragma unroll
-            for (int l = 0; l < mma_C::ne/2; ++l) {
-                const int j = j0 + mma_C::get_j(l);
+            for (int l = 0; l < tile_C::ne/2; ++l) {
+                const int j = j0 + tile_C::get_j(l);
 
                 dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
             }
 
 #pragma unroll
             for (int n = 0; n < ntx; ++n) {
-                mma_C C[2];
-                C[0].mma_K4(A[n][k01/4 + 0], B[0]);
-                C[1].mma_K4(A[n][k01/4 + 1], B[1]);
+                tile_C C[2];
+                mma(C[0], A[n][k01/4 + 0], B[0]);
+                mma(C[1], A[n][k01/4 + 1], B[1]);
 
 #pragma unroll
-                for (int l = 0; l < mma_C::ne; ++l) {
+                for (int l = 0; l < tile_C::ne; ++l) {
                     tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2];
                 }
             }
@@ -1797,28 +1804,28 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 #pragma unroll
         for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-            for (int l = 0; l < mma_C::ne; ++l) {
-                sum[(j0/mma_C::J + n)*mma_C::ne + l] += tmp[n][l]*dA[n][l/2];
+            for (int l = 0; l < tile_C::ne; ++l) {
+                sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp[n][l]*dA[n][l/2];
             }
         }
     }
 #else
     GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
     NO_DEVICE_CODE;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 }
 
 template  static __device__ __forceinline__ void load_tiles_iq4_nl(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kbx  = threadIdx.x / QI4_NL;
     const int kqsx = threadIdx.x % QI4_NL;
@@ -1836,13 +1843,13 @@ template  static __device__ __forceinlin
         const int aux_q4 = get_int_b2(bxi->qs, kqsx);
         const int2 v = get_int_from_table_16(aux_q4);
         const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
 #else
         x_qs[i*(2*WARP_SIZE + 1)     + k0 + 0] = v.x;
         x_qs[i*(2*WARP_SIZE + 1)     + k0 + 4] = v.y;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI4_NL;
@@ -1858,25 +1865,25 @@ template  static __device__ __forceinlin
 
         const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
 #else
         x_df[i*(WARP_SIZE/4) + i/4   + kbxd] = __half2float(bxi->d);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
 template  static __device__ __forceinline__ void load_tiles_iq2_xxs(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kqsx = threadIdx.x % (QI2_XXS/2);
 
@@ -1905,36 +1912,36 @@ template  static __device__ __forceinlin
             const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
             const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;
 #else
             x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 0)] = grid0;
             x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 1)] = grid1;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
         }
 
         const int ls = aux32 >> 28;
         const float d = bxi->d;
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
 #else
         x_df[i*(WARP_SIZE/4) + i/4   + kqsx] = (ls*d + d/2)/4;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
 template  static __device__ __forceinline__ void load_tiles_iq2_xs(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
 #else
     constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kqsx = threadIdx.x % (QI2_XS/2);
 
@@ -1959,38 +1966,38 @@ template  static __device__ __forceinlin
             const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
             const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
 #else
             x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 1)] = grid_h;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
         }
 
         const int ls = bxi->scales[kqsx];
         const float d = bxi->d;
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q3_K               + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
         x_df[i*MMQ_MMA_TILE_X_K_Q3_K               + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
 #else
         x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
         x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
 template  static __device__ __forceinline__ void load_tiles_iq2_s(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kqsx = threadIdx.x % (QI2_S/2);
 
@@ -2022,38 +2029,38 @@ template  static __device__ __forceinlin
             const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
             const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
 #else
             x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 1)] = grid_h;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
         }
 
         const int ls = bxi->scales[kqsx];
         const float d = bxi->d;
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q3_K               + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
         x_df[i*MMQ_MMA_TILE_X_K_Q3_K               + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
 #else
         x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls &  0x0F)*d + d/2)/4;
         x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >>    4)*d + d/2)/4;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
 template  static __device__ __forceinline__ void load_tiles_iq3_xxs(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kqsx = threadIdx.x % (QI3_XXS/2);
 
@@ -2080,36 +2087,36 @@ template  static __device__ __forceinlin
             const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
             const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;
 #else
             x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 0)] = grid_l;
             x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l + 1)] = grid_h;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
         }
 
         const int ls = aux32 >> 28;
         const float d = bxi->d;
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
 #else
         x_df[i*(WARP_SIZE/4) + i/4   + kqsx] = (ls*d + d/2)/2;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
 template  static __device__ __forceinline__ void load_tiles_iq3_s(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kqsx = threadIdx.x % (QI3_S/2);
 
@@ -2143,36 +2150,36 @@ template  static __device__ __forceinlin
             const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
             const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;
 #else
             x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l+0)] = grid_l;
             x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l+1)] = grid_h;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
         }
 
         const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);
         const float d = bxi->d;
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
 #else
         x_df[i*(WARP_SIZE/4) + i/4   + kqsx] = ls*d;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
 template  static __device__ __forceinline__ void load_tiles_iq1_s(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     half2 * x_ds = (half2 *) (x_qs + WARP_SIZE*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     half2 * x_ds = (half2 *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kqsx = threadIdx.x % QI1_S;
 
@@ -2198,37 +2205,37 @@ template  static __device__ __forceinlin
             const int grid0 = (grid >> 0) & 0x0F0F0F0F;
             const int grid1 = (grid >> 4) & 0x0F0F0F0F;
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
 #else
             x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l+0)] = grid0;
             x_qs[i*(2*WARP_SIZE + 1)     + 8*kqsx + (2*l+1)] = grid1;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
         }
 
         const float  d1q   = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
         const float  delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
 #else
         x_ds[i*(WARP_SIZE/4) + i/4   + kqsx] = make_half2(d1q, d1q*delta);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
 template  static __device__ __forceinline__ void load_tiles_iq4_xs(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + WARP_SIZE*2);
 #else
     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
     int   * x_qs = (int   *)  x_tile;
     float * x_df = (float *) (x_qs + txs.qs);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     const int kbx  = 0;           // threadIdx.x / QI4_XS
     const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
@@ -2246,13 +2253,13 @@ template  static __device__ __forceinlin
         const int aux_q4 = get_int_b4(bxi->qs, kqsx);
         const int2 v = get_int_from_table_16(aux_q4);
         const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
 #else
         x_qs[i*(2*WARP_SIZE + 1)     + k0 + 0] = v.x;
         x_qs[i*(2*WARP_SIZE + 1)     + k0 + 4] = v.y;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 
 #pragma unroll
@@ -2270,11 +2277,11 @@ template  static __device__ __forceinlin
         const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
             | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
         x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
 #else
         x_df[i*(WARP_SIZE/4) + i/4   + threadIdx.x % 8] = d * (ls - 32);
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
     }
 }
 
@@ -2307,36 +2314,36 @@ template
 static __device__ __forceinline__ void mmq_write_back_mma(
     const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
 
-    typedef mma_int_C_I16J8 mma_C;
+    typedef tile<16, 8, int> tile_C;
 
     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = 2 * granularity;
-    constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp.
+    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
 
-    const int i0 = (threadIdx.y / ntx) * (ntx*mma_C::I);
-#ifdef INT8_MMA_AVAILABLE
-    static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
-#endif // INT8_MMA_AVAILABLE
+    const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
+#ifdef NEW_MMA_AVAILABLE
+    static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
+#endif // NEW_MMA_AVAILABLE
 
 #pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) {
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
 #pragma unroll
         for (int n = 0; n < ntx; ++n) {
 #pragma unroll
-            for (int l = 0; l < mma_C::ne; ++l) {
-                const int j = j0 + (threadIdx.y % ntx) * mma_C::J + mma_C::get_j(l);
+            for (int l = 0; l < tile_C::ne; ++l) {
+                const int j = j0 + (threadIdx.y % ntx) * tile_C::J + tile_C::get_j(l);
 
                 if (j > j_max) {
                     continue;
                 }
 
-                const int i = i0 + n*mma_C::I + mma_C::get_i(l);
+                const int i = i0 + n*tile_C::I + tile_C::get_i(l);
 
                 if (need_check && i > i_max) {
                     continue;
                 }
 
-                dst[j*stride + i] = sum[(j0/mma_C::J + n)*mma_C::ne + l];
+                dst[j*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l];
             }
         }
     }
@@ -2505,13 +2512,13 @@ static __device__ void mul_mat_q_process_tile(
     int * tile_y = (int *) data_mul_mat_q;
     int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE);
 
-#ifdef INT8_MMA_AVAILABLE
+#ifdef NEW_MMA_AVAILABLE
     constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits::vec_dot_mma;
     constexpr mmq_write_back_t write_back = mmq_write_back_mma;
 #else
     constexpr vec_dot_mmq_t    vec_dot    = mmq_type_traits::vec_dot_dp4a;
     constexpr mmq_write_back_t write_back = mmq_write_back_dp4a;
-#endif // INT8_MMA_AVAILABLE
+#endif // NEW_MMA_AVAILABLE
 
     constexpr int blocks_per_iter = MMQ_ITER_K / qk;
 
@@ -2643,7 +2650,7 @@ static __global__ void mul_mat_q(
     const int jt =  kbc /    (blocks_per_ne00*nty);
     const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
 
-    constexpr bool fixup = true; // Last index writes it data to fixup buffer to avoid data races with other blocks.
+    constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
     mul_mat_q_process_tile
         (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
             it, jt, kb0_start, kb0_stop);
@@ -2749,7 +2756,7 @@ template
 static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) {
     const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
     const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
-    const int shmem_x = int8_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
+    const int shmem_x = new_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
     const int shmem_y = mmq_x*sizeof(block_q8_1_mmq);
     return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
 }
@@ -2825,7 +2832,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
     const int mmq_x_max = get_mmq_x_max_host(cc);
     const int mmq_y = get_mmq_y_host(cc);
     const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
-    const bool use_stream_k = cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD;
+    const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD;
 
     int mmq_x_best  = 0;
     int nparts_best = INT_MAX;
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmv.cu b/ml/backend/ggml/ggml/src/ggml-cuda/mmv.cu
index ac45f2d1..f89ed03b 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/mmv.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/mmv.cu
@@ -1,25 +1,29 @@
+#include "ggml.h"
 #include "common.cuh"
 #include "mmv.cuh"
 
 template 
 static __global__ void mul_mat_vec(
         const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
-        const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) {
-    const int64_t row     = blockIdx.x;
-    const int64_t channel = blockIdx.z;
-    const int     tid     = threadIdx.x;
+        const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
+        const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) {
+    const int64_t row       = blockIdx.x;
+    const int64_t channel   = blockIdx.y;
+    const int64_t sample    = blockIdx.z;
+    const int     tid       = threadIdx.x;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-    x   += (channel/channel_ratio)*stride_channel_x + row*stride_row;
-    y   +=  channel               *stride_channel_y;
-    dst +=  channel               *stride_channel_dst;
+    x   +=  (sample/sample_ratio)*stride_sample_x   + (channel/channel_ratio)*stride_channel_x + row*stride_row;
+    y   +=   sample              *stride_sample_y   +  channel               *stride_channel_y;
+    dst +=   sample              *stride_sample_dst +  channel               *stride_channel_dst;
 
     const float2 * y2 = (const float2 *) y;
 
     extern __shared__ char data_mmv[];
     float * buf_iw = (float *) data_mmv;
 
-    if (block_size > WARP_SIZE) {
-        if (tid < WARP_SIZE) {
+    if (block_size > warp_size) {
+        if (tid < warp_size) {
             buf_iw[tid] = 0.0f;
         }
         __syncthreads();
@@ -67,16 +71,16 @@ static __global__ void mul_mat_vec(
         static_assert(std::is_same::value, "unsupported type");
     }
 
-    sumf = warp_reduce_sum(sumf);
+    sumf = warp_reduce_sum(sumf);
 
-    if (block_size > WARP_SIZE) {
-        buf_iw[tid/WARP_SIZE] = sumf;
+    if (block_size > warp_size) {
+        buf_iw[tid/warp_size] = sumf;
         __syncthreads();
-        if (tid >= WARP_SIZE) {
+        if (tid >= warp_size) {
             return;
         }
         sumf = buf_iw[tid];
-        sumf = warp_reduce_sum(sumf);
+        sumf = warp_reduce_sum(sumf);
     }
 
     if (tid != 0) {
@@ -90,16 +94,28 @@ template 
 static void launch_mul_mat_vec_cuda(
         const T * x, const float * y, float * dst,
         const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
-        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
         cudaStream_t stream) {
     GGML_ASSERT(ncols      % 2 == 0);
     GGML_ASSERT(stride_row % 2 == 0);
     GGML_ASSERT(nchannels_y % nchannels_x == 0);
+    GGML_ASSERT(nsamples_y  % nsamples_x  == 0);
     const int64_t channel_ratio = nchannels_y / nchannels_x;
+    const int64_t sample_ratio  = nsamples_y  / nsamples_x;
+    int device;
+    int warp_size;
 
-    int64_t block_size_best = WARP_SIZE;
-    int64_t niter_best      = (ncols + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
-    for (int64_t block_size = 2*WARP_SIZE; block_size <= 256; block_size += WARP_SIZE) {
+    CUDA_CHECK(cudaGetDevice(&device));
+    warp_size = ggml_cuda_info().devices[device].warp_size;
+
+    int64_t block_size_best = warp_size;
+    int64_t niter_best      = (ncols + 2*warp_size - 1) / (2*warp_size);
+    int64_t max_block_size  = 256;
+    if(ggml_cuda_info().devices[device].cc > GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_info().devices[device].cc < GGML_CUDA_CC_RDNA1) {
+        max_block_size = 128;
+    }
+    for (int64_t block_size = 2*warp_size; block_size <= max_block_size; block_size += warp_size) {
         const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);
         if (niter < niter_best) {
             niter_best      = niter;
@@ -107,41 +123,49 @@ static void launch_mul_mat_vec_cuda(
         }
     }
 
-    const int smem = WARP_SIZE*sizeof(float);
-    const dim3 block_nums(nrows, 1, nchannels_y);
+    const int smem = warp_size*sizeof(float);
+    const dim3 block_nums(nrows, nchannels_y, nsamples_y);
     const dim3 block_dims(block_size_best, 1, 1);
     switch (block_size_best) {
         case   32: {
             mul_mat_vec<<>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case   64: {
             mul_mat_vec<<>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case   96: {
             mul_mat_vec<<>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  128: {
             mul_mat_vec<<>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  160: {
             mul_mat_vec<<>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  192: {
             mul_mat_vec<<>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  224: {
             mul_mat_vec<<>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  256: {
             mul_mat_vec<<>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         default: {
             GGML_ABORT("fatal error");
@@ -153,16 +177,19 @@ template
 static void mul_mat_vec_cuda(
         const T * x, const float * y, float * dst,
         const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
-        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
         enum ggml_prec prec, cudaStream_t stream) {
     switch (prec) {
         case GGML_PREC_DEFAULT: {
-            launch_mul_mat_vec_cuda(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
-                stride_channel_x, stride_channel_y, stride_channel_dst, stream);
+            launch_mul_mat_vec_cuda
+                (x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
         case GGML_PREC_F32: {
-            launch_mul_mat_vec_cuda(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
-                stride_channel_x, stride_channel_y, stride_channel_dst, stream);
+            launch_mul_mat_vec_cuda
+                (x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
     }
 }
@@ -171,10 +198,19 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type  == GGML_TYPE_F32);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
+    GGML_TENSOR_BINARY_OP_LOCALS;
 
-    GGML_ASSERT(src1->ne[1] == 1);
+    const size_t ts_src0 = ggml_type_size(src0->type);
+    const size_t ts_src1 = ggml_type_size(src1->type);
+    const size_t ts_dst  = ggml_type_size(dst->type);
+
+    GGML_ASSERT(ne11 == 1);
+    GGML_ASSERT(ne12 == ne2);
+    GGML_ASSERT(ne13 == ne3);
+
+    GGML_ASSERT(nb00 == ts_src0);
+    GGML_ASSERT(nb10 == ts_src1);
+    GGML_ASSERT(nb0  == ts_dst);
 
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
@@ -182,29 +218,22 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
     const float * src1_d = (const float *) src1->data;
     float       *  dst_d = (float       *)  dst->data;
 
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne12 = src1->ne[2];
-    GGML_ASSERT(dst->ne[2] == ne12);
-
-    GGML_ASSERT(src0->ne[3] == 1);
-    GGML_ASSERT(src1->ne[3] == 1);
-    GGML_ASSERT( dst->ne[3] == 1);
-
-    const int64_t stride_row         = src0->nb[1] / ggml_type_size(src0->type);
-    const int64_t channel_stride_x   = src0->nb[2] / ggml_type_size(src0->type);
-    const int64_t channel_stride_y   = src1->nb[2] / ggml_type_size(src1->type);
-    const int64_t channel_stride_dst =  dst->nb[2] / ggml_type_size( dst->type);
+    const int64_t s01 = src0->nb[1] / ts_src0;
+    const int64_t s02 = src0->nb[2] / ts_src0;
+    const int64_t s12 = src1->nb[2] / ts_src1;
+    const int64_t s2  =  dst->nb[2] / ts_dst;
+    const int64_t s03 = src0->nb[3] / ts_src0;
+    const int64_t s13 = src1->nb[3] / ts_src1;
+    const int64_t s3  =  dst->nb[3] / ts_dst;
 
     switch (src0->type) {
         case GGML_TYPE_F16: {
             const half * src0_d = (const half *) src0->data;
-            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
-                channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
+            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream());
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
-            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
-                channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
+            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream());
         } break;
         default:
             GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
@@ -233,20 +262,27 @@ void ggml_cuda_op_mul_mat_vec(
     const int64_t stride_row         = ne00;
     const int64_t nchannels_x        = 1;
     const int64_t nchannels_y        = 1;
-    const int64_t channel_stride_x   = 0;
-    const int64_t channel_stride_y   = 0;
-    const int64_t channel_stride_dst = 0;
+    const int64_t stride_channel_x   = 0;
+    const int64_t stride_channel_y   = 0;
+    const int64_t stride_channel_dst = 0;
+    const int64_t nsamples_x         = 1;
+    const int64_t nsamples_y         = 1;
+    const int64_t stride_sample_x    = 0;
+    const int64_t stride_sample_y    = 0;
+    const int64_t stride_sample_dst  = 0;
 
     switch (src0->type) {
         case GGML_TYPE_F16: {
             const half * src0_d = (const half *) src0_dd_i;
             mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
-                nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
+                nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
             mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
-                nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
+                nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
         } break;
         default:
             GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cu b/ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cu
index e3b912d8..4fb466ca 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cu
@@ -142,7 +142,7 @@ static void mul_mat_vec_q_cuda(
     int64_t nwarps = 1;
     int64_t rows_per_cuda_block = 1;
 
-    if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_CDNA || ggml_cuda_info().devices[id].cc == GGML_CUDA_CC_RDNA1) { // NVIDIA and AMD older than RDNA2 but not CDNA
+    if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_RDNA2) { // NVIDIA and AMD older than RDNA2
         switch(ncols_y) {
             case 1:
                 nwarps = 4;
@@ -166,6 +166,7 @@ static void mul_mat_vec_q_cuda(
                 break;
         }
     }
+
     const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
     const dim3 block_nums(nblocks, 1, 1);
     const dim3 block_dims(WARP_SIZE, nwarps, 1);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/norm.cu b/ml/backend/ggml/ggml/src/ggml-cuda/norm.cu
index 133e219f..f127616e 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/norm.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/norm.cu
@@ -1,24 +1,36 @@
 #include "norm.cuh"
+#include 
 
 template 
-static __global__ void norm_f32(const float * x, float * dst, const int ncols, const float eps) {
-    const int row = blockIdx.x*blockDim.y + threadIdx.y;
-    const int tid = threadIdx.x;
+static __global__ void norm_f32(
+        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+        const int64_t stride_sample, const float eps) {
+    const int nrows     = gridDim.x;
+    const int nchannels = gridDim.y;
 
-    float2 mean_var = make_float2(0.f, 0.f);
+    const int row       = blockIdx.x;
+    const int channel   = blockIdx.y;
+    const int sample    = blockIdx.z;
+    const int tid       = threadIdx.x;
+
+    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
+    dst += ((sample*nchannels + channel)*nrows + row)*ncols;
+
+    float2 mean_var = make_float2(0.0f, 0.0f);
 
     for (int col = tid; col < ncols; col += block_size) {
-        const float xi = x[row*ncols + col];
+        const float xi = x[col];
         mean_var.x += xi;
         mean_var.y += xi * xi;
     }
 
     // sum up partial sums
     mean_var = warp_reduce_sum(mean_var);
-    if (block_size > WARP_SIZE) {
+    if constexpr (block_size > WARP_SIZE) {
+        static_assert(block_size == 1024, "unexpected block_size");
         __shared__ float2 s_sum[32];
-        int warp_id = threadIdx.x / WARP_SIZE;
-        int lane_id = threadIdx.x % WARP_SIZE;
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
         if (lane_id == 0) {
             s_sum[warp_id] = mean_var;
         }
@@ -32,7 +44,7 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols, c
     const float inv_std = rsqrtf(var + eps);
 
     for (int col = tid; col < ncols; col += block_size) {
-        dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std;
+        dst[col] = (x[col] - mean) * inv_std;
     }
 }
 
@@ -40,14 +52,8 @@ template 
 static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
     // blockIdx.x: num_groups idx
     // threadIdx.x: block_size idx
-    int start = blockIdx.x * group_size;
-    int end = start + group_size;
-
-    start += threadIdx.x;
-
-    if (end >= ne_elements) {
-        end = ne_elements;
-    }
+    const int start =     blockIdx.x*group_size + threadIdx.x;
+    const int end   = min(blockIdx.x*group_size + group_size,  ne_elements);
 
     float tmp = 0.0f; // partial sum for thread in warp
 
@@ -56,10 +62,11 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
     }
 
     tmp = warp_reduce_sum(tmp);
-    if (block_size > WARP_SIZE) {
+    if constexpr (block_size > WARP_SIZE) {
+        static_assert(block_size == 1024, "unexpected block_size");
         __shared__ float s_sum[32];
-        int warp_id = threadIdx.x / WARP_SIZE;
-        int lane_id = threadIdx.x % WARP_SIZE;
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
         if (lane_id == 0) {
             s_sum[warp_id] = tmp;
         }
@@ -68,11 +75,11 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
         tmp = warp_reduce_sum(tmp);
     }
 
-    float mean = tmp / group_size;
+    const float mean = tmp / group_size;
     tmp = 0.0f;
 
     for (int j = start; j < end; j += block_size) {
-        float xi = x[j] - mean;
+        const float xi = x[j] - mean;
         dst[j] = xi;
         tmp += xi * xi;
     }
@@ -80,8 +87,8 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
     tmp = warp_reduce_sum(tmp);
     if (block_size > WARP_SIZE) {
         __shared__ float s_sum[32];
-        int warp_id = threadIdx.x / WARP_SIZE;
-        int lane_id = threadIdx.x % WARP_SIZE;
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
         if (lane_id == 0) {
             s_sum[warp_id] = tmp;
         }
@@ -90,31 +97,42 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
         tmp = warp_reduce_sum(tmp);
     }
 
-    float variance = tmp / group_size;
-    float scale = rsqrtf(variance + eps);
+    const float variance = tmp / group_size;
+    const float scale = rsqrtf(variance + eps);
     for (int j = start; j < end; j += block_size) {
         dst[j] *= scale;
     }
 }
 
 template 
-static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
-    const int row = blockIdx.x*blockDim.y + threadIdx.y;
-    const int tid = threadIdx.x;
+static __global__ void rms_norm_f32(
+        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+        const int64_t stride_sample, const float eps) {
+    const int nrows     = gridDim.x;
+    const int nchannels = gridDim.y;
+
+    const int row       = blockIdx.x;
+    const int channel   = blockIdx.y;
+    const int sample    = blockIdx.z;
+    const int tid       = threadIdx.x;
+
+    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
+    dst += ((sample*nchannels + channel)*nrows + row)*ncols;
 
     float tmp = 0.0f; // partial sum for thread in warp
 
     for (int col = tid; col < ncols; col += block_size) {
-        const float xi = x[row*ncols + col];
+        const float xi = x[col];
         tmp += xi * xi;
     }
 
     // sum up partial sums
     tmp = warp_reduce_sum(tmp);
-    if (block_size > WARP_SIZE) {
+    if constexpr (block_size > WARP_SIZE) {
+        static_assert(block_size == 1024, "unexpected block_size");
         __shared__ float s_sum[32];
-        int warp_id = threadIdx.x / WARP_SIZE;
-        int lane_id = threadIdx.x % WARP_SIZE;
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
         if (lane_id == 0) {
             s_sum[warp_id] = tmp;
         }
@@ -127,22 +145,77 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
     const float scale = rsqrtf(mean + eps);
 
     for (int col = tid; col < ncols; col += block_size) {
-        dst[row*ncols + col] = scale * x[row*ncols + col];
+        dst[col] = scale * x[col];
     }
 }
 
-static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
-    GGML_ASSERT(ncols % WARP_SIZE == 0);
+template 
+static __global__ void rms_norm_back_f32(
+        const float * grad, const float * xf, float * dst, const int ncols, const float eps) {
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    const int tid = threadIdx.x;
+
+    grad += int64_t(row)*ncols;
+    xf   += int64_t(row)*ncols;
+    dst  += int64_t(row)*ncols;
+
+    float sum_xx = 0.0f; // sum for squares of x, equivalent to forward pass
+    float sum_xg = 0.0f; // sum for x * gradient, needed because RMS norm mixes inputs
+
+    for (int col = tid; col < ncols; col += block_size) {
+        const float xfi = xf[col];
+        sum_xx += xfi * xfi;
+        sum_xg += xfi * grad[col];
+    }
+
+    // sum up partial sums
+    sum_xx = warp_reduce_sum(sum_xx);
+    sum_xg = warp_reduce_sum(sum_xg);
+    if constexpr (block_size > WARP_SIZE) {
+        static_assert(block_size == 1024, "unexpected block_size");
+        __shared__ float s_sum_xx[32];
+        __shared__ float s_sum_xg[32];
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum_xx[warp_id] = sum_xx;
+            s_sum_xg[warp_id] = sum_xg;
+        }
+        __syncthreads();
+
+        sum_xx = s_sum_xx[lane_id];
+        sum_xx = warp_reduce_sum(sum_xx);
+
+        sum_xg = s_sum_xg[lane_id];
+        sum_xg = warp_reduce_sum(sum_xg);
+    }
+
+    const float mean_eps = sum_xx / ncols + eps;
+    const float sum_eps  = sum_xx + ncols*eps;
+
+    const float scale_grad = rsqrtf(mean_eps);
+    const float scale_x    = -scale_grad * sum_xg/sum_eps;
+
+    for (int col = tid; col < ncols; col += block_size) {
+        dst[col] = scale_grad*grad[col] + scale_x*xf[col];
+    }
+}
+
+static void norm_f32_cuda(
+        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
+    const dim3 blocks_num(nrows, nchannels, nsamples);
     if (ncols < 1024) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
-        norm_f32<<>>(x, dst, ncols, eps);
+        norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        norm_f32<1024><<>>(x, dst, ncols, eps);
+        norm_f32<1024><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     }
 }
 
-static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) {
+static void group_norm_f32_cuda(
+        const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) {
     if (group_size < 1024) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
         group_norm_f32<<>>(x, dst, group_size, ne_elements, eps);
@@ -152,35 +225,51 @@ static void group_norm_f32_cuda(const float * x, float * dst, const int num_grou
     }
 }
 
-static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
-    GGML_ASSERT(ncols % WARP_SIZE == 0);
+static void rms_norm_f32_cuda(
+        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
+    const dim3 blocks_num(nrows, nchannels, nsamples);
     if (ncols < 1024) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
-        rms_norm_f32<<>>(x, dst, ncols, eps);
+        rms_norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        rms_norm_f32<1024><<>>(x, dst, ncols, eps);
+        rms_norm_f32<1024><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+    }
+}
+
+static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
+    if (ncols < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        rms_norm_back_f32<<>>(grad, xf, dst, ncols, eps);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        rms_norm_back_f32<1024><<>>(grad, xf, dst, ncols, eps);
     }
 }
 
 void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const float * src0_d = (const float *) src0->data;
+    float * dst_d = (float *) dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+    GGML_TENSOR_UNARY_OP_LOCALS;
 
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
+    GGML_ASSERT(eps >= 0.0f);
 
-    norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
+    const size_t ts0 = ggml_type_size(src0->type);
+    GGML_ASSERT(nb00 == ts0);
+    const int64_t s01 = nb01 / ts0;
+    const int64_t s02 = nb02 / ts0;
+    const int64_t s03 = nb03 / ts0;
+
+    norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
 }
 
 void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -189,8 +278,6 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
@@ -198,6 +285,7 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
 
     float eps;
     memcpy(&eps, dst->op_params + 1, sizeof(float));
+    GGML_ASSERT(eps >= 0.0f);
 
     int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
     group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], eps, group_size, ggml_nelements(src0), stream);
@@ -205,20 +293,50 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
 
 void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const float * src0_d = (const float *) src0->data;
+    float * dst_d = (float *) dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+    GGML_TENSOR_UNARY_OP_LOCALS;
 
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
+    GGML_ASSERT(eps >= 0.0f);
 
-    rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
+    const size_t ts0 = ggml_type_size(src0->type);
+    GGML_ASSERT(nb00 == ts0);
+    const int64_t s01 = nb01 / ts0;
+    const int64_t s02 = nb02 / ts0;
+    const int64_t s03 = nb03 / ts0;
+
+    rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
+}
+
+void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * grad  = dst->src[0]; // gradients
+    const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass
+
+    const float * grad_d  = (const float *) grad->data;
+    const float * src0f_d = (const float *) src0f->data;
+    float       * dst_d   = (float       *) dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(grad));
+
+    GGML_ASSERT( grad->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0f->type == GGML_TYPE_F32);
+    GGML_ASSERT(  dst->type == GGML_TYPE_F32);
+
+    const int64_t ne00 = src0f->ne[0];
+    const int64_t nrows = ggml_nrows(src0f);
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+    GGML_ASSERT(eps >= 0.0f);
+
+    rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream);
 }
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/norm.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/norm.cuh
index 431a8f74..d63d3438 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/norm.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/norm.cuh
@@ -5,3 +5,5 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/out-prod.cu b/ml/backend/ggml/ggml/src/ggml-cuda/out-prod.cu
index 619cfdcb..c9b2b699 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/out-prod.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/out-prod.cu
@@ -11,16 +11,15 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type  == GGML_TYPE_F32);
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(dst));
 
     GGML_ASSERT(ne01 == ne11);
     GGML_ASSERT(ne0 == ne00);
     GGML_ASSERT(ne1 == ne10);
 
-    GGML_ASSERT(ne2 == src0->ne[2]);
+    GGML_ASSERT(ne2 % src0->ne[2] == 0);
+    GGML_ASSERT(ne3 % src0->ne[3] == 0);
+
     GGML_ASSERT(ne2 == src1->ne[2]);
-    GGML_ASSERT(ne3 == src0->ne[3]);
     GGML_ASSERT(ne3 == src1->ne[3]);
 
     const float * src0_d = (const float *) src0->data;
@@ -33,19 +32,37 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const float alpha = 1.0f;
     const float beta = 0.0f;
 
-    GGML_ASSERT(ne2 == 1);
-    GGML_ASSERT(ne3 == 1);
     CUBLAS_CHECK(cublasSetStream(handle, stream));
 
+    const int64_t lda = nb01 / sizeof(float);
+    const int64_t ldc = nb1  / sizeof(float);
+
     const bool src1_T = ggml_is_transposed(src1);
     const cublasOperation_t src1_cublas_op =  src1_T ? CUBLAS_OP_N : CUBLAS_OP_T;
     const int64_t           ldb            = (src1_T ?        nb10 :        nb11) /  sizeof(float);
     GGML_ASSERT(                             (src1_T ?        nb11 :        nb10) == sizeof(float));
 
-    CUBLAS_CHECK(
-        cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
-                ne0, ne1, ne01,
-                &alpha, src0_d, ne00,
-                        src1_d, ldb,
-                &beta,  dst_d,  ne0));
+    // data strides in dimensions 2/3
+    const size_t s02 = nb02 / sizeof(float);
+    const size_t s03 = nb03 / sizeof(float);
+    const size_t s12 = nb12 / sizeof(float);
+    const size_t s13 = nb13 / sizeof(float);
+    const size_t s2  = nb2  / sizeof(float);
+    const size_t s3  = nb3  / sizeof(float);
+
+    // dps == dst per src0, used for group query attention
+    const int64_t dps2 = ne2 / ne02;
+    const int64_t dps3 = ne3 / ne03;
+
+    // TODO batched matrix multiplication
+    for (int64_t i3 = 0; i3 < ne3; ++i3) {
+        for (int64_t i2 = 0; i2 < ne2; ++i2) {
+            CUBLAS_CHECK(
+                cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
+                        ne0, ne1, ne01,
+                        &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda,
+                                src1_d +  i3      *s13 +  i2      *s12, ldb,
+                        &beta,  dst_d  +  i3      *s3  +  i2      *s2,  ldc));
+        }
+    }
 }
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/pad.cu b/ml/backend/ggml/ggml/src/ggml-cuda/pad.cu
index 39fd4b16..b4b87409 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/pad.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/pad.cu
@@ -92,4 +92,4 @@ void ggml_cuda_op_unpad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     unpad_f32_cuda(src0_d, dst_d,
         src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
         dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
-}
+}
\ No newline at end of file
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/rope.cu b/ml/backend/ggml/ggml/src/ggml-cuda/rope.cu
index 2c84778d..18f691b2 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/rope.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/rope.cu
@@ -16,9 +16,10 @@ static __device__ float rope_yarn_ramp(const float low, const float high, const
 
 // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
 // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+template
 static __device__ void rope_yarn(
-    float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
-    float * cos_theta, float * sin_theta) {
+        const float theta_extrap, const float freq_scale, const rope_corr_dims corr_dims, const int64_t i0, const float ext_factor,
+        float mscale, float & cos_theta, float & sin_theta) {
     // Get n-d rotational scaling corrected for extrapolation
     float theta_interp = freq_scale * theta_extrap;
     float theta = theta_interp;
@@ -29,24 +30,28 @@ static __device__ void rope_yarn(
         // Get n-d magnitude scaling corrected for interpolation
         mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
     }
-    *cos_theta = cosf(theta) * mscale;
-    *sin_theta = sinf(theta) * mscale;
+    cos_theta = cosf(theta) * mscale;
+    sin_theta = sinf(theta) * mscale;
+    if (!forward) {
+        sin_theta *= -1.0f;
+    }
 }
 
-template
+template
 static __global__ void rope_norm(
-    const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
+        const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
+        const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (i0 >= ne0) {
         return;
     }
 
-    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+    const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i0 >= n_dims) {
-        const int i = row*ne0 + i0;
+        const int i = row_dst*ne0 + i0;
 
         dst[i + 0] = x[i + 0];
         dst[i + 1] = x[i + 1];
@@ -54,39 +59,43 @@ static __global__ void rope_norm(
         return;
     }
 
-    const int i  = row*ne0 + i0;
-    const int i2 = row/p_delta_rows;
+    const int row_x     = row_dst % ne1;
+    const int channel_x = row_dst / ne1;
 
-    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
+    const int idst = row_dst*ne0 + i0;
+    const int ix   = channel_x*s2 + row_x*s1 + i0;
+
+    const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
 
     const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
 
     float cos_theta;
     float sin_theta;
 
-    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
 
-    const float x0 = x[i + 0];
-    const float x1 = x[i + 1];
+    const float x0 = x[ix + 0];
+    const float x1 = x[ix + 1];
 
-    dst[i + 0] = x0*cos_theta - x1*sin_theta;
-    dst[i + 1] = x0*sin_theta + x1*cos_theta;
+    dst[idst + 0] = x0*cos_theta - x1*sin_theta;
+    dst[idst + 1] = x0*sin_theta + x1*cos_theta;
 }
 
-template
+template
 static __global__ void rope_neox(
-    const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors) {
+        const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
+        const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (i0 >= ne0) {
         return;
     }
 
-    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+    const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i0 >= n_dims) {
-        const int i = row*ne0 + i0;
+        const int i = row_dst*ne0 + i0;
 
         dst[i + 0] = x[i + 0];
         dst[i + 1] = x[i + 1];
@@ -94,39 +103,43 @@ static __global__ void rope_neox(
         return;
     }
 
-    const int i  = row*ne0 + i0/2;
-    const int i2 = row/p_delta_rows;
+    const int row_x     = row_dst % ne1;
+    const int channel_x = row_dst / ne1;
 
-    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
+    const int idst = row_dst*ne0 + i0/2;
+    const int ix   = channel_x*s2 + row_x*s1 + i0/2;
+
+    const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
 
     const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
 
     float cos_theta;
     float sin_theta;
 
-    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
 
-    const float x0 = x[i + 0];
-    const float x1 = x[i + n_dims/2];
+    const float x0 = x[ix + 0];
+    const float x1 = x[ix + n_dims/2];
 
-    dst[i + 0]        = x0*cos_theta - x1*sin_theta;
-    dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
+    dst[idst + 0]        = x0*cos_theta - x1*sin_theta;
+    dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
 }
 
-template
+template
 static __global__ void rope_multi(
-    const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) {
+        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
+        const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (i0 >= ne0) {
         return;
     }
 
-    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+    const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i0 >= n_dims) {
-        const int i = row*ne0 + i0;
+        const int i = row_dst*ne0 + i0;
 
         dst[i + 0] = x[i + 0];
         dst[i + 1] = x[i + 1];
@@ -134,25 +147,28 @@ static __global__ void rope_multi(
         return;
     }
 
-    const int i  = row*ne0 + i0/2;
-    const int i2 = row/p_delta_rows;
+    const int row_x     = row_dst % ne1;
+    const int channel_x = row_dst / ne1;
 
-    int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
-    int sec_w = sections.v[1] + sections.v[0];
-    int sector = (i0 / 2) % sect_dims;
+    const int idst = row_dst*ne0 + i0/2;
+    const int ix   = channel_x*s2 + row_x*s1 + i0/2;
+
+    const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
+    const int sec_w = sections.v[1] + sections.v[0];
+    const int sector = (i0 / 2) % sect_dims;
 
     float theta_base = 0.0;
     if (sector < sections.v[0]) {
-        theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
+        theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
     }
     else if (sector >= sections.v[0] && sector < sec_w) {
-        theta_base = pos[i2 + ne2 * 1]*powf(theta_scale, i0/2.0f);
+        theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
     }
     else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
-        theta_base = pos[i2 + ne2 * 2]*powf(theta_scale, i0/2.0f);
+        theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
     }
     else if (sector >= sec_w + sections.v[2]) {
-        theta_base = pos[i2 + ne2 * 3]*powf(theta_scale, i0/2.0f);
+        theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
     }
 
     const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@@ -160,42 +176,46 @@ static __global__ void rope_multi(
     float cos_theta;
     float sin_theta;
 
-    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
 
-    const float x0 = x[i + 0];
-    const float x1 = x[i + n_dims/2];
+    const float x0 = x[ix + 0];
+    const float x1 = x[ix + n_dims/2];
 
-    dst[i + 0]        = x0*cos_theta - x1*sin_theta;
-    dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
+    dst[idst + 0]        = x0*cos_theta - x1*sin_theta;
+    dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
 }
 
-template
+template
 static __global__ void rope_vision(
-    const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) {
+        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims,
+        const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
+        const float theta_scale, const float * freq_factors, const mrope_sections sections) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (i0 >= ne0) {
         return;
     }
 
-    const int row = blockDim.x*blockIdx.x + threadIdx.x;
+    const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
 
-    const int i  = row*ne0 + i0/2;
-    const int i2 = row/p_delta_rows; // i2-th tokens
+    const int row_x     = row_dst % ne1;
+    const int channel_x = row_dst / ne1;
 
-    int sect_dims = sections.v[0] + sections.v[1];
-    int sec_w = sections.v[1] + sections.v[0];
-    int sector = (i0 / 2) % sect_dims;
+    const int idst = row_dst*ne0 + i0/2;
+    const int ix   = channel_x*s2 + row_x*s1 + i0/2;
+
+    const int sect_dims = sections.v[0] + sections.v[1];
+    const int sec_w = sections.v[1] + sections.v[0];
+    const int sector = (i0 / 2) % sect_dims;
 
     float theta_base = 0.0;
     if (sector < sections.v[0]) {
         const int p = sector;
-        theta_base = pos[i2]*powf(theta_scale, p);
+        theta_base = pos[channel_x]*powf(theta_scale, p);
     }
     else if (sector >= sections.v[0] && sector < sec_w) {
         const int p = sector - sections.v[0];
-        theta_base = pos[i2 + ne2]*powf(theta_scale, p);
+        theta_base = pos[channel_x + ne2]*powf(theta_scale, p);
     }
 
     const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@@ -203,19 +223,20 @@ static __global__ void rope_vision(
     float cos_theta;
     float sin_theta;
 
-    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
 
-    const float x0 = x[i + 0];
-    const float x1 = x[i + n_dims];
+    const float x0 = x[ix + 0];
+    const float x1 = x[ix + n_dims];
 
-    dst[i + 0]      = x0*cos_theta - x1*sin_theta;
-    dst[i + n_dims] = x0*sin_theta + x1*cos_theta;
+    dst[idst + 0]      = x0*cos_theta - x1*sin_theta;
+    dst[idst + n_dims] = x0*sin_theta + x1*cos_theta;
 }
 
-template
+template
 static void rope_norm_cuda(
-    const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+        const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
+        const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -224,22 +245,21 @@ static void rope_norm_cuda(
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
     if (freq_factors == nullptr) {
-        rope_norm<<>>(
-                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors
-                );
+        rope_norm<<>>(
+            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
+            attn_factor, corr_dims, theta_scale, freq_factors);
     } else {
-        rope_norm<<>>(
-                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors
-                );
+        rope_norm<<>>(
+            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
+            attn_factor, corr_dims, theta_scale, freq_factors);
     }
 }
 
-template
+template
 static void rope_neox_cuda(
-    const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+        const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
+        const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -248,22 +268,21 @@ static void rope_neox_cuda(
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
     if (freq_factors == nullptr) {
-        rope_neox<<>>(
-                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors
-                );
+        rope_neox<<>>(
+            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
+            attn_factor, corr_dims, theta_scale, freq_factors);
     } else {
-        rope_neox<<>>(
-                x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors
-                );
+        rope_neox<<>>(
+            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
+            attn_factor, corr_dims, theta_scale, freq_factors);
     }
 }
 
-template
+template
 static void rope_multi_cuda(
-    const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
+        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
+        const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -272,22 +291,21 @@ static void rope_multi_cuda(
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
     if (freq_factors == nullptr) {
-        rope_multi<<>>(
-                x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors, sections
-                );
+        rope_multi<<>>(
+            x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
+            attn_factor, corr_dims, theta_scale, freq_factors, sections);
     } else {
-        rope_multi<<>>(
-                x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors, sections
-                );
+        rope_multi<<>>(
+            x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
+            attn_factor, corr_dims, theta_scale, freq_factors, sections);
     }
 }
 
-template
+template
 static void rope_vision_cuda(
-    const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
+        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
+        const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
+        const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -298,80 +316,18 @@ static void rope_vision_cuda(
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
     if (freq_factors == nullptr) {
-        rope_vision<<>>(
-                x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors, sections
-                );
+        rope_vision<<>>(
+            x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
+            attn_factor, corr_dims, theta_scale, freq_factors, sections);
     } else {
-        rope_vision<<>>(
-                x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
-                theta_scale, freq_factors, sections
-                );
+        rope_vision<<>>(
+            x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
+            attn_factor, corr_dims, theta_scale, freq_factors, sections);
     }
 }
 
-static void rope_norm_cuda_f16(
-    const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
-
-    rope_norm_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
-}
-
-static void rope_norm_cuda_f32(
-    const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
-
-    rope_norm_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
-}
-
-static void rope_neox_cuda_f16(
-    const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
-
-    rope_neox_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
-}
-
-static void rope_neox_cuda_f32(
-    const float * x, float * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream
-) {
-
-    rope_neox_cuda(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
-}
-
-static void rope_multi_cuda_f16(
-    const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
-) {
-
-    rope_multi_cuda(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
-}
-
-static void rope_multi_cuda_f32(
-    const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
-) {
-
-    rope_multi_cuda(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
-}
-
-static void rope_vision_cuda_f16(
-    const half * x, half * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
-) {
-
-    rope_vision_cuda(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
-}
-
-static void rope_vision_cuda_f32(
-    const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
-) {
-
-    rope_vision_cuda(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
-}
-
-void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+template 
+void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
     const ggml_tensor * src2 = dst->src[2];
@@ -382,7 +338,6 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
     GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
     GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
     GGML_ASSERT(src0->type == dst->type);
@@ -392,6 +347,9 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int64_t ne02 = src0->ne[2]; // num heads
     const int64_t nr = ggml_nrows(src0);
 
+    const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);
+    const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);
+
     //const int n_past     = ((int32_t *) dst->op_params)[0];
     const int n_dims     = ((int32_t *) dst->op_params)[1];
     const int mode       = ((int32_t *) dst->op_params)[2];
@@ -440,59 +398,59 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     // compute
     if (is_neox) {
         if (src0->type == GGML_TYPE_F32) {
-            rope_neox_cuda_f32(
-                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, stream
-            );
+            rope_neox_cuda(
+                (const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
         } else if (src0->type == GGML_TYPE_F16) {
-            rope_neox_cuda_f16(
-                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, stream
-            );
+            rope_neox_cuda(
+                (const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
         } else {
             GGML_ABORT("fatal error");
         }
     } else if (is_mrope && !is_vision) {
         if (src0->type == GGML_TYPE_F32) {
-            rope_multi_cuda_f32(
-                (const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, sections, stream
-            );
+            rope_multi_cuda(
+                (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
         } else if (src0->type == GGML_TYPE_F16) {
-            rope_multi_cuda_f16(
-                (const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, sections, stream
-            );
+            rope_multi_cuda(
+                (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
         } else {
             GGML_ABORT("fatal error");
         }
     } else if (is_vision) {
         if (src0->type == GGML_TYPE_F32) {
-            rope_vision_cuda_f32(
-                (const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, sections, stream
-            );
+            rope_vision_cuda(
+                (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
         } else if (src0->type == GGML_TYPE_F16) {
-            rope_vision_cuda_f16(
-                (const half *)src0_d, (half *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, sections, stream
-            );
+            rope_vision_cuda(
+                (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
         } else {
             GGML_ABORT("fatal error");
         }
     } else {
         if (src0->type == GGML_TYPE_F32) {
-            rope_norm_cuda_f32(
-                (const float *)src0_d, (float *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, stream
-            );
+            rope_norm_cuda(
+                (const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
         } else if (src0->type == GGML_TYPE_F16) {
-            rope_norm_cuda_f16(
-                (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, stream
-            );
+            rope_norm_cuda(
+                (const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
+                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
         } else {
             GGML_ABORT("fatal error");
         }
     }
 }
+
+void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_rope_impl(ctx, dst);
+}
+
+void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_rope_impl(ctx, dst);
+}
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/rope.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/rope.cuh
index 0f787a0b..9139f3b2 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/rope.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/rope.cuh
@@ -3,3 +3,5 @@
 #define CUDA_ROPE_BLOCK_SIZE 256
 
 void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/softmax.cu b/ml/backend/ggml/ggml/src/ggml-cuda/softmax.cu
index c24abae1..aac6e099 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/softmax.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/softmax.cu
@@ -1,5 +1,7 @@
 #include "common.cuh"
+#include "ggml.h"
 #include "softmax.cuh"
+#include 
 
 template 
 static __device__ __forceinline__ float t2f32(T val) {
@@ -11,14 +13,26 @@ __device__ float __forceinline__ t2f32(half val) {
     return __half2float(val);
 }
 
-template 
-static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
+// When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled.
+// As we want to keep pragma unroll for all other cases we supress the clang transformation warning here.
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wpass-failed"
+#endif // __clang__
+template 
+static __global__ void soft_max_f32(
+        const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
+        const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
     const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
 
     const int tid  = threadIdx.x;
     const int rowx = blockIdx.x;
     const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
 
+    x    += int64_t(rowx)*ncols;
+    mask += int64_t(rowy)*ncols * (mask != nullptr);
+    dst  += int64_t(rowx)*ncols;
+
     const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
 
     const int warp_id = threadIdx.x / WARP_SIZE;
@@ -29,7 +43,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
     extern __shared__ float data_soft_max_f32[];
     float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
     // shared memory buffer to cache values between iterations:
-    float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + (int64_t)rowx*ncols;
+    float * vals = use_shared ? buf_iw + WARP_SIZE : dst;
 
     float max_val = -INFINITY;
 
@@ -41,10 +55,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
             break;
         }
 
-        const int64_t ix = (int64_t)rowx*ncols + col;
-        const int64_t iy = (int64_t)rowy*ncols + col;
-
-        const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f);
+        const float val = x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
 
         vals[col] = val;
         max_val = max(max_val, val);
@@ -110,8 +121,32 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
             return;
         }
 
-        const int64_t idst = (int64_t)rowx*ncols + col;
-        dst[idst] = vals[col] * inv_sum;
+        dst[col] = vals[col] * inv_sum;
+    }
+}
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif // __clang__
+
+static __global__ void soft_max_back_f32(
+        const float * grad, const float * dstf, float * dst, const int ncols, const float scale) {
+    const int tid  = threadIdx.x;
+    const int rowx = blockIdx.x;
+
+    grad += int64_t(rowx)*ncols;
+    dstf += int64_t(rowx)*ncols;
+    dst  += int64_t(rowx)*ncols;
+
+    float dgf_dot = 0.0f; // dot product of dst from forward pass and gradients
+
+    for (int col = tid; col < ncols; col += WARP_SIZE) {
+        dgf_dot += dstf[col]*grad[col];
+    }
+
+    dgf_dot = warp_reduce_sum(dgf_dot);
+
+    for (int col = tid; col < ncols; col += WARP_SIZE) {
+        dst[col] = scale * (grad[col] - dgf_dot) * dstf[col];
     }
 }
 
@@ -121,7 +156,7 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth,     1, 1);
     const dim3 block_nums(nrows_x, 1, 1);
-    const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
+    const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
     static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
 
     const uint32_t n_head      = nrows_x/nrows_y;
@@ -131,50 +166,68 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
     // FIXME: this limit could be raised by ~2-4x on Ampere or newer
-    if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
+    if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
         switch (ncols_x) {
             case 32:
-                soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<<>>
+                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 64:
-                soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<<>>
+                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 128:
-                soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<<>>
+                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 256:
-                soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<<>>
+                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 512:
-                soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<<>>
+                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 1024:
-                soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<<>>
+                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 2048:
-                soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<<>>
+                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 4096:
-                soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<<>>
+                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             default:
-                soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<<>>
+                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
         }
     } else {
-        const size_t shmem_low = WARP_SIZE*sizeof(float);
-        soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+        const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
+        soft_max_f32<<>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
     }
 }
 
+static void soft_max_back_f32_cuda(
+        const float * grad, const float * dstf, float * dst,
+        const int ncols, const int nrows, const float scale, cudaStream_t stream) {
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    const dim3 block_nums(nrows,     1, 1);
+
+    soft_max_back_f32<<>>(grad, dstf, dst, ncols, scale);
+}
+
 void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
 
-    const float * src0_d = (const float *)src0->data;
-    const void  * src1_d = src1 ? (const void *)src1->data : nullptr;
+    const float * src0_d = (const float *) src0->data;
+    const void  * src1_d = src1 ? (const void *) src1->data : nullptr;
+    float       *  dst_d = (float *) dst->data;
 
-    float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
@@ -189,18 +242,42 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float scale    = 1.0f;
     float max_bias = 0.0f;
 
-    memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+    memcpy(&scale,    (const float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
 
     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
     if (use_f16) {
-        const half * src1_dd = (const half *)src1_d;
-
-        soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+        soft_max_f32_cuda(src0_d, (const half  *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
     } else {
-        const float * src1_dd = (const float *)src1_d;
-
-        soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+        soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
     }
 }
+
+void ggml_cuda_op_soft_max_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0]; // grad
+    const ggml_tensor * src1 = dst->src[1]; // forward pass output
+
+    const float * src0_d = (const float *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    float       * dst_d  = (float       *) dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    const int64_t ncols = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (const float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
+
+    GGML_ASSERT(max_bias == 0.0f);
+
+    soft_max_back_f32_cuda(src0_d, src1_d, dst_d, ncols, nrows, scale, stream);
+}
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/softmax.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/softmax.cuh
index 4ef4ff86..93dfee83 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/softmax.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/softmax.cuh
@@ -3,3 +3,5 @@
 #define CUDA_SOFT_MAX_BLOCK_SIZE 1024
 
 void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_soft_max_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/sum.cu b/ml/backend/ggml/ggml/src/ggml-cuda/sum.cu
index e0dafc1d..f9589080 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/sum.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/sum.cu
@@ -1,6 +1,6 @@
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11700
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
 #define USE_CUB
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11700
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
 
 #ifdef USE_CUB
 #include 
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu
new file mode 100644
index 00000000..80108615
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 1, 8);
+DECL_FATTN_MMA_F16_CASE(80, 1, 8);
+DECL_FATTN_MMA_F16_CASE(96, 1, 8);
+DECL_FATTN_MMA_F16_CASE(112, 1, 8);
+DECL_FATTN_MMA_F16_CASE(128, 1, 8);
+DECL_FATTN_MMA_F16_CASE(256, 1, 8);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu
new file mode 100644
index 00000000..66161c0a
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 16, 1);
+DECL_FATTN_MMA_F16_CASE(80, 16, 1);
+DECL_FATTN_MMA_F16_CASE(96, 16, 1);
+DECL_FATTN_MMA_F16_CASE(112, 16, 1);
+DECL_FATTN_MMA_F16_CASE(128, 16, 1);
+DECL_FATTN_MMA_F16_CASE(256, 16, 1);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu
new file mode 100644
index 00000000..ee88c72a
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 16, 2);
+DECL_FATTN_MMA_F16_CASE(80, 16, 2);
+DECL_FATTN_MMA_F16_CASE(96, 16, 2);
+DECL_FATTN_MMA_F16_CASE(112, 16, 2);
+DECL_FATTN_MMA_F16_CASE(128, 16, 2);
+DECL_FATTN_MMA_F16_CASE(256, 16, 2);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
new file mode 100644
index 00000000..d888a5a4
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 16, 4);
+DECL_FATTN_MMA_F16_CASE(80, 16, 4);
+DECL_FATTN_MMA_F16_CASE(96, 16, 4);
+DECL_FATTN_MMA_F16_CASE(112, 16, 4);
+DECL_FATTN_MMA_F16_CASE(128, 16, 4);
+DECL_FATTN_MMA_F16_CASE(256, 16, 4);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
new file mode 100644
index 00000000..d93a2d08
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 2, 4);
+DECL_FATTN_MMA_F16_CASE(80, 2, 4);
+DECL_FATTN_MMA_F16_CASE(96, 2, 4);
+DECL_FATTN_MMA_F16_CASE(112, 2, 4);
+DECL_FATTN_MMA_F16_CASE(128, 2, 4);
+DECL_FATTN_MMA_F16_CASE(256, 2, 4);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu
new file mode 100644
index 00000000..617464c9
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 2, 8);
+DECL_FATTN_MMA_F16_CASE(80, 2, 8);
+DECL_FATTN_MMA_F16_CASE(96, 2, 8);
+DECL_FATTN_MMA_F16_CASE(112, 2, 8);
+DECL_FATTN_MMA_F16_CASE(128, 2, 8);
+DECL_FATTN_MMA_F16_CASE(256, 2, 8);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu
new file mode 100644
index 00000000..970d2b68
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 32, 1);
+DECL_FATTN_MMA_F16_CASE(80, 32, 1);
+DECL_FATTN_MMA_F16_CASE(96, 32, 1);
+DECL_FATTN_MMA_F16_CASE(112, 32, 1);
+DECL_FATTN_MMA_F16_CASE(128, 32, 1);
+DECL_FATTN_MMA_F16_CASE(256, 32, 1);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu
new file mode 100644
index 00000000..65cd377c
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 32, 2);
+DECL_FATTN_MMA_F16_CASE(80, 32, 2);
+DECL_FATTN_MMA_F16_CASE(96, 32, 2);
+DECL_FATTN_MMA_F16_CASE(112, 32, 2);
+DECL_FATTN_MMA_F16_CASE(128, 32, 2);
+DECL_FATTN_MMA_F16_CASE(256, 32, 2);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu
new file mode 100644
index 00000000..f4a8bf34
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 4, 2);
+DECL_FATTN_MMA_F16_CASE(80, 4, 2);
+DECL_FATTN_MMA_F16_CASE(96, 4, 2);
+DECL_FATTN_MMA_F16_CASE(112, 4, 2);
+DECL_FATTN_MMA_F16_CASE(128, 4, 2);
+DECL_FATTN_MMA_F16_CASE(256, 4, 2);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
new file mode 100644
index 00000000..de191a8a
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 4, 4);
+DECL_FATTN_MMA_F16_CASE(80, 4, 4);
+DECL_FATTN_MMA_F16_CASE(96, 4, 4);
+DECL_FATTN_MMA_F16_CASE(112, 4, 4);
+DECL_FATTN_MMA_F16_CASE(128, 4, 4);
+DECL_FATTN_MMA_F16_CASE(256, 4, 4);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu
new file mode 100644
index 00000000..e8cb0e1b
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 4, 8);
+DECL_FATTN_MMA_F16_CASE(80, 4, 8);
+DECL_FATTN_MMA_F16_CASE(96, 4, 8);
+DECL_FATTN_MMA_F16_CASE(112, 4, 8);
+DECL_FATTN_MMA_F16_CASE(128, 4, 8);
+DECL_FATTN_MMA_F16_CASE(256, 4, 8);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu
new file mode 100644
index 00000000..a532e962
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 64, 1);
+DECL_FATTN_MMA_F16_CASE(80, 64, 1);
+DECL_FATTN_MMA_F16_CASE(96, 64, 1);
+DECL_FATTN_MMA_F16_CASE(112, 64, 1);
+DECL_FATTN_MMA_F16_CASE(128, 64, 1);
+DECL_FATTN_MMA_F16_CASE(256, 64, 1);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu
new file mode 100644
index 00000000..bf25181a
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 8, 1);
+DECL_FATTN_MMA_F16_CASE(80, 8, 1);
+DECL_FATTN_MMA_F16_CASE(96, 8, 1);
+DECL_FATTN_MMA_F16_CASE(112, 8, 1);
+DECL_FATTN_MMA_F16_CASE(128, 8, 1);
+DECL_FATTN_MMA_F16_CASE(256, 8, 1);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu
new file mode 100644
index 00000000..378c132e
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 8, 2);
+DECL_FATTN_MMA_F16_CASE(80, 8, 2);
+DECL_FATTN_MMA_F16_CASE(96, 8, 2);
+DECL_FATTN_MMA_F16_CASE(112, 8, 2);
+DECL_FATTN_MMA_F16_CASE(128, 8, 2);
+DECL_FATTN_MMA_F16_CASE(256, 8, 2);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
new file mode 100644
index 00000000..372641be
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 8, 4);
+DECL_FATTN_MMA_F16_CASE(80, 8, 4);
+DECL_FATTN_MMA_F16_CASE(96, 8, 4);
+DECL_FATTN_MMA_F16_CASE(112, 8, 4);
+DECL_FATTN_MMA_F16_CASE(128, 8, 4);
+DECL_FATTN_MMA_F16_CASE(256, 8, 4);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu
new file mode 100644
index 00000000..9ff5968b
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 8, 8);
+DECL_FATTN_MMA_F16_CASE(80, 8, 8);
+DECL_FATTN_MMA_F16_CASE(96, 8, 8);
+DECL_FATTN_MMA_F16_CASE(112, 8, 8);
+DECL_FATTN_MMA_F16_CASE(128, 8, 8);
+DECL_FATTN_MMA_F16_CASE(256, 8, 8);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/unary.cu b/ml/backend/ggml/ggml/src/ggml-cuda/unary.cu
index 81fc9220..6b21f407 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/unary.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/unary.cu
@@ -51,6 +51,19 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] / (1.0f + expf(-x[i]));
 }
 
+static __global__ void silu_back_f32(
+        const float * grad, const float * xf, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    const float xfi = xf[i];
+    const float s = 1.0f / (1.0f + expf(-xfi));
+    dst[i] = grad[i] * s * (1.0f + xfi * (1.0f - s));
+}
+
 static __global__ void tanh_f32(const float * x, float * dst, int k) {
     const int i  = blockDim.x*blockIdx.x + threadIdx.x;
     if (i >= k) {
@@ -173,6 +186,11 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
     silu_f32<<>>(x, dst, k);
 }
 
+static void silu_back_f32_cuda(const float * grad, const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SILU_BACK_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
+    silu_back_f32<<>>(grad, x, dst, k);
+}
+
 static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
     tanh_f32<<>>(x, dst, k);
@@ -284,6 +302,24 @@ void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     silu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
 }
 
+void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0]; // input from forward pass
+    const ggml_tensor * src1 = dst->src[1]; // grads of forward pass output
+
+    const float * src0_d = (const float *) src0->data;
+    const float * src1_d = (const float *) src1->data;
+    float       * dst_d  = (float       *) dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    silu_back_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(src0), stream);
+}
+
 void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/unary.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/unary.cuh
index c9193672..e7f62643 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/unary.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/unary.cuh
@@ -4,6 +4,7 @@
 #define CUDA_STEP_BLOCK_SIZE 256
 #define CUDA_GELU_BLOCK_SIZE 256
 #define CUDA_SILU_BLOCK_SIZE 256
+#define CUDA_SILU_BACK_BLOCK_SIZE 256
 #define CUDA_TANH_BLOCK_SIZE 256
 #define CUDA_RELU_BLOCK_SIZE 256
 #define CUDA_SIGMOID_BLOCK_SIZE 256
@@ -23,6 +24,8 @@ void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
+void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
 void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/wkv6.cu b/ml/backend/ggml/ggml/src/ggml-cuda/wkv6.cu
index 42578341..bbdafbee 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/wkv6.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/wkv6.cu
@@ -73,9 +73,9 @@ void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     const float * s_d  = (const float *)dst->src[5]->data;
 
     const int64_t B = dst->src[5]->ne[1];
-    const int64_t T = dst->src[0]->ne[3];
+    const int64_t T = dst->src[0]->ne[2];
     const int64_t C = dst->ne[0];
-    const int64_t H = dst->src[0]->ne[2];
+    const int64_t H = dst->src[0]->ne[1];
 
     float * dst_d = (float *)dst->data;
 
diff --git a/ml/backend/ggml/ggml/src/ggml-hip/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-hip/CMakeLists.txt
index b15fbd24..4a0384dd 100644
--- a/ml/backend/ggml/ggml/src/ggml-hip/CMakeLists.txt
+++ b/ml/backend/ggml/ggml/src/ggml-hip/CMakeLists.txt
@@ -40,13 +40,20 @@ find_package(hip     REQUIRED)
 find_package(hipblas REQUIRED)
 find_package(rocblas REQUIRED)
 
+if (${hip_VERSION} VERSION_LESS 5.5)
+    message(FATAL_ERROR "At least ROCM/HIP V5.5 is required")
+endif()
+
 message(STATUS "HIP and hipBLAS found")
 
+# Workaround old compilers
+set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} --gpu-max-threads-per-block=1024")
+
 file(GLOB   GGML_HEADERS_ROCM "../ggml-cuda/*.cuh")
 list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h")
 
 file(GLOB   GGML_SOURCES_ROCM "../ggml-cuda/*.cu")
-file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-wmma*.cu")
+file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
 list(APPEND GGML_SOURCES_ROCM ${SRCS})
 file(GLOB   SRCS "../ggml-cuda/template-instances/mmq*.cu")
 list(APPEND GGML_SOURCES_ROCM ${SRCS})
@@ -70,7 +77,9 @@ ggml_add_backend_library(ggml-hip
                         )
 
 # TODO: do not use CUDA definitions for HIP
-target_compile_definitions(ggml PUBLIC GGML_USE_CUDA)
+if (NOT GGML_BACKEND_DL)
+    target_compile_definitions(ggml PUBLIC GGML_USE_CUDA)
+endif()
 
 add_compile_definitions(GGML_USE_HIP)
 
@@ -90,6 +99,18 @@ if (GGML_CUDA_NO_PEER_COPY)
     add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
 endif()
 
+if (GGML_HIP_GRAPHS)
+    add_compile_definitions(GGML_HIP_GRAPHS)
+endif()
+
+if (GGML_HIP_NO_VMM)
+    add_compile_definitions(GGML_HIP_NO_VMM)
+endif()
+
+if (NOT GGML_CUDA_FA)
+    add_compile_definitions(GGML_CUDA_NO_FA)
+endif()
+
 if (CXX_IS_HIPCC)
     set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX)
     target_link_libraries(ggml-hip PRIVATE hip::device)
diff --git a/ml/backend/ggml/ggml/src/ggml-impl.h b/ml/backend/ggml/ggml/src/ggml-impl.h
index 549772c5..1fbcbd04 100644
--- a/ml/backend/ggml/ggml/src/ggml-impl.h
+++ b/ml/backend/ggml/ggml/src/ggml-impl.h
@@ -3,6 +3,8 @@
 // GGML internal header
 
 #include "ggml.h"
+#include "gguf.h"
+
 #include 
 #include 
 #include  // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/
@@ -14,7 +16,7 @@
 #include 
 #endif // __ARM_FEATURE_SVE
 
-#if defined(__ARM_NEON) && !defined(__CUDACC__)
+#if defined(__ARM_NEON) && !defined(__CUDACC__) && !defined(__MUSACC__)
 // if YCM cannot find , make a symbolic link to it, for example:
 //
 //   $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
@@ -551,22 +553,15 @@ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
 #define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
 #define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
 
-// expose GGUF internals for test code
-
-GGML_API size_t gguf_type_size(enum gguf_type type);
-
-GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
-
-struct gguf_buf {
-    void * data;
-    size_t size;
-    size_t offset;
-};
-GGML_API struct gguf_buf gguf_buf_init(size_t size);
-GGML_API void gguf_buf_free(struct gguf_buf buf);
-
-GGML_API void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta);
-
 #ifdef __cplusplus
 }
 #endif
+
+#ifdef __cplusplus
+#include 
+
+// expose GGUF internals for test code
+GGML_API size_t gguf_type_size(enum gguf_type type);
+GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
+GGML_API void gguf_write_to_buf(const struct gguf_context * ctx, std::vector & buf, bool only_meta);
+#endif // __cplusplus
diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal
index f10966df..c3610ac0 100644
--- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal
+++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal
@@ -477,7 +477,6 @@ GGML_TABLE_BEGIN(uint8_t, ksigns_iq2xs, 128)
     240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
 GGML_TABLE_END()
 
-//#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A // lowest compute capability for integer intrinsics
 GGML_TABLE_BEGIN(uint64_t, ksigns64, 128)
     0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff,
     0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff,
@@ -512,7 +511,6 @@ GGML_TABLE_BEGIN(uint64_t, ksigns64, 128)
     0x00ffffffff000000, 0xffffffffff0000ff, 0xffffffffff00ff00, 0x00ffffffff00ffff,
     0xffffffffffff0000, 0x00ffffffffff00ff, 0x00ffffffffffff00, 0xffffffffffffffff,
 GGML_TABLE_END()
-//#endif
 
 
 GGML_TABLE_BEGIN(uint64_t, iq2xxs_grid, 256)
@@ -2513,24 +2511,33 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
 template 
 void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
     const half d_all = xb->d;
-    device const uint8_t * ql = (device const uint8_t *)xb->ql;
-    device const uint8_t * qh = (device const uint8_t *)xb->qh;
+    device const uint16_t * ql = (device const uint16_t *)xb->ql;
+    device const uint16_t * qh = (device const uint16_t *)xb->qh;
     device const int8_t * scales = (device const int8_t *)xb->scales;
 
-    ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
-    qh = qh + 32*(il/8) + 16*(il&1);
+    ql = ql + 32*(il/8) + 16*((il/2)&1) + 8*(il&1);
+    qh = qh + 16*(il/8) + 8*(il&1);
     float sc = scales[(il%2) + 2 * ((il/2))];
     il = (il/2) & 3;
 
-    const uint16_t  kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
-    const uint16_t  kmask2 = il>1 ? 0xF0              : 0x0F;
-    const float       coef = il>1 ? 1.f/16.f          : 1.f;
+    const uint32_t kmask1 = il>1 ? (il>2 ? 0xC0C0C0C0 : 0x30303030) : (il>0 ? 0x0C0C0C0C : 0x03030303);
+    const uint32_t kmask2 = il>1 ? 0xF0F0F0F0                       : 0x0F0F0F0F;
     const float ml = d_all * sc * 32.f;
-    const float dl = d_all * sc * coef;
-    for (int i = 0; i < 16; ++i) {
-        const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
-                            : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
-        reg[i/4][i%4] = dl * q - ml;
+    const float dl0 = d_all * sc;
+    const float dl1 = dl0 / 256.f;
+    const float dl2 = dl0 / (256.f * 256.f);
+    const float dl3 = dl0 / (256.f * 256.f * 256.f);
+    const uint8_t shr_h = il>2 ? 2 : 0;
+    const uint8_t shl_h = il>1 ? 0 : (il>0 ? 2 : 4);
+    const uint8_t shr_l = il>1 ? 4 : 0;
+    for (int i = 0; i < 4; ++i) {
+        const uint32_t  low = (ql[2*i] | (uint32_t)(ql[2*i+1] << 16)) & kmask2;
+        const uint32_t high = (qh[2*i] | (uint32_t)(qh[2*i+1] << 16)) & kmask1;
+        const uint32_t q = ((high << shl_h) >> shr_h) | (low >> shr_l);
+        reg[i][0] = dl0 *  ((half)(q & 0xFF))       - ml;
+        reg[i][1] = dl1 * ((float)(q & 0xFF00))     - ml;
+        reg[i][2] = dl2 * ((float)(q & 0xFF0000))   - ml;
+        reg[i][3] = dl3 * ((float)(q & 0xFF000000)) - ml;
     }
 }
 
@@ -3198,7 +3205,7 @@ kernel void kernel_soft_max(
     }
 
     // This barrier fixes a failing test
-    // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+    // ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335
     threadgroup_barrier(mem_flags::mem_none);
 
     float sum = simd_sum(lsum);
@@ -3303,7 +3310,7 @@ kernel void kernel_soft_max_4(
     const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
 
     // This barrier fixes a failing test
-    // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+    // ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335
     threadgroup_barrier(mem_flags::mem_none);
 
     float sum = simd_sum(lsum);
@@ -6517,6 +6524,49 @@ kernel void kernel_cpy_f32_iq4_nl(
     }
 }
 
+template
+kernel void kernel_cpy_q_f32(
+        constant ggml_metal_kargs_cpy & args,
+        device  const char * src0,
+        device        char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig[2];
+    const int i02 = tgpig[1];
+    const int i01 = tgpig[0];
+
+    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
+
+    const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
+    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
+    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
+    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
+
+    device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
+    device       T4x4    * dst_data = (device       T4x4    *)(dst  +  i3*args.nb3  +  i2*args.nb2  +  i1*args.nb1 + i0*args.nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < args.ne00/16; i00 += ntg.x) {
+        T4x4 temp;
+        dequantize_func(src_data + i00/nl, i00%nl, temp);
+        dst_data[i00] = temp;
+    }
+}
+
+typedef decltype(kernel_cpy_q_f32) cpy_q_f_t;
+
+template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+
+template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+template [[host_name("kernel_cpy_q5_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+template [[host_name("kernel_cpy_q8_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+
 kernel void kernel_concat(
     constant ggml_metal_kargs_concat & args,
     device  const char * src0,
@@ -6601,7 +6651,6 @@ void kernel_mul_mv_q2_K_f32_impl(
         device const half     * dh = &x[ib].d;
 
         for (int row = 0; row < N_DST; row++) {
-
             float4 acc1 = {0.f, 0.f, 0.f, 0.f};
             float4 acc2 = {0.f, 0.f, 0.f, 0.f};
             for (int i = 0; i < 8; i += 2) {
@@ -6632,7 +6681,7 @@ void kernel_mul_mv_q2_K_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum;
@@ -6798,7 +6847,7 @@ void kernel_mul_mv_q3_K_f32_impl(
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
     if (tiisg == 0) {
-        for (int row = 0; row < 2; ++row) {
+        for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
             dst_f32[first_row + row] = sumf1[row];
         }
     }
@@ -6914,7 +6963,7 @@ void kernel_mul_mv_q4_K_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum;
@@ -7046,7 +7095,7 @@ void kernel_mul_mv_q5_K_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < 2; ++row) {
+    for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
         const float tot = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = tot;
@@ -7091,6 +7140,10 @@ void kernel_mul_mv_q6_K_f32_impl(
 
     const int row = 2*r0 + sgitg;
 
+    if (row >= args.ne0) {
+        return;
+    }
+
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
 
@@ -7246,7 +7299,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum * 0.25f;
@@ -7364,7 +7417,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum * 0.25f;
@@ -7474,7 +7527,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum * 0.5f;
@@ -7586,7 +7639,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum;
@@ -7699,7 +7752,7 @@ void kernel_mul_mv_iq2_s_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum * 0.25f;
@@ -7799,7 +7852,7 @@ void kernel_mul_mv_iq1_s_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum;
@@ -7894,7 +7947,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum;
@@ -7984,7 +8037,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) {
+    for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum;
@@ -8073,7 +8126,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < 2; ++row) {
+    for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum;
diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m
index 318addec..e4c093f9 100644
--- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m
+++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m
@@ -19,7 +19,17 @@
 // max number of MTLCommandBuffer used to submit a graph for processing
 #define GGML_METAL_MAX_COMMAND_BUFFERS 8
 
-#define UNUSED(x) (void)(x)
+#ifndef TARGET_OS_VISION
+#define TARGET_OS_VISION 0
+#endif
+
+// create residency sets only on macOS >= 15.0
+#if !TARGET_CPU_X86_64 && TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000 || \
+    TARGET_OS_IOS && __IPHONE_OS_VERSION_MAX_ALLOWED >= 180000 || \
+    TARGET_OS_TV && __TV_OS_VERSION_MAX_ALLOWED >= 180000 || \
+    TARGET_OS_VISION && __VISION_OS_VERSION_MAX_ALLOWED >= 200000
+#define GGML_METAL_HAS_RESIDENCY_SETS 1
+#endif
 
 // globals
 
@@ -39,6 +49,7 @@ static struct ggml_backend_metal_device_context {
 
     bool has_simdgroup_reduction;
     bool has_simdgroup_mm;
+    bool has_residency_sets;
     bool has_bfloat;
     bool use_bfloat;
 
@@ -48,6 +59,7 @@ static struct ggml_backend_metal_device_context {
     /*.mtl_device_ref_count    =*/ 0,
     /*.has_simdgroup_reduction =*/ false,
     /*.has_simdgroup_mm        =*/ false,
+    /*.has_residency_sets      =*/ false,
     /*.has_bfloat              =*/ false,
     /*.use_bfloat              =*/ false,
     /*.name                    =*/ "",
@@ -59,12 +71,18 @@ static id ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
 
     if (ctx->mtl_device == nil) {
         ctx->mtl_device = MTLCreateSystemDefaultDevice();
+    }
 
+    if (ctx->mtl_device) {
         ctx->has_simdgroup_reduction  = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
         ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
 
         ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
 
+#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
+        ctx->has_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == NULL;
+#endif
+
         ctx->has_bfloat  = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
         ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
 
@@ -90,8 +108,10 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     ctx->mtl_device_ref_count--;
 
     if (ctx->mtl_device_ref_count == 0) {
-        [ctx->mtl_device release];
-        ctx->mtl_device = nil;
+        if (ctx->mtl_device) {
+            [ctx->mtl_device release];
+            ctx->mtl_device = nil;
+        }
     }
 }
 
@@ -388,6 +408,16 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
     GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
+    GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16,
     GGML_METAL_KERNEL_TYPE_CONCAT,
     GGML_METAL_KERNEL_TYPE_SQR,
     GGML_METAL_KERNEL_TYPE_SQRT,
@@ -484,6 +514,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
 
     ctx->queue  = [device newCommandQueue];
+    if (ctx->queue == nil) {
+        GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
+        return NULL;
+    }
+
     ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
 
     id metal_library;
@@ -650,6 +685,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
 
     GGML_LOG_INFO("%s: simdgroup reduction   = %s\n", __func__, ctx_dev->has_simdgroup_reduction     ? "true" : "false");
     GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, ctx_dev->has_simdgroup_mm            ? "true" : "false");
+    GGML_LOG_INFO("%s: has residency sets    = %s\n", __func__, ctx_dev->has_residency_sets          ? "true" : "false");
     GGML_LOG_INFO("%s: has bfloat            = %s\n", __func__, ctx_dev->has_bfloat                  ? "true" : "false");
     GGML_LOG_INFO("%s: use bfloat            = %s\n", __func__, ctx_dev->use_bfloat                  ? "true" : "false");
     GGML_LOG_INFO("%s: hasUnifiedMemory      = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
@@ -988,6 +1024,16 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,                  cpy_f32_q5_0,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,                  cpy_f32_q5_1,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,                cpy_f32_iq4_nl,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32,                  cpy_q4_0_f32,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16,                  cpy_q4_0_f16,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32,                  cpy_q4_1_f32,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16,                  cpy_q4_1_f16,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32,                  cpy_q5_0_f32,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16,                  cpy_q5_0_f16,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32,                  cpy_q5_1_f32,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16,                  cpy_q5_1_f16,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32,                  cpy_q8_0_f32,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16,                  cpy_q8_0_f16,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT,                        concat,                         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR,                           sqr,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT,                          sqrt,                           true);
@@ -1037,8 +1083,70 @@ struct ggml_backend_metal_buffer_context {
     // multiple buffers are used only to avoid the maximum buffer size limitation when using mmap
     int n_buffers;
     struct ggml_backend_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
+
+    // optional MTLResidencySet
+    id rset;
 };
 
+// rset init
+static bool ggml_backend_metal_buffer_rset_init(
+        struct ggml_backend_metal_buffer_context * ctx,
+        struct ggml_backend_metal_device_context * ctx_dev,
+        id device) {
+    ctx->rset = nil;
+
+    if (!ctx_dev->has_residency_sets) {
+        return true;
+    }
+
+#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
+    if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) {
+        MTLResidencySetDescriptor * desc = [[MTLResidencySetDescriptor alloc] init];
+        desc.label = @"ggml_backend_metal";
+        desc.initialCapacity = ctx->n_buffers;
+
+        NSError * error;
+        ctx->rset = [device newResidencySetWithDescriptor:desc error:&error];
+        if (error) {
+            GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
+            [desc release];
+            return false;
+        }
+
+        [desc release];
+
+        for (int i = 0; i < ctx->n_buffers; i++) {
+            [ctx->rset addAllocation:ctx->buffers[i].metal];
+        }
+
+        [ctx->rset commit];
+        [ctx->rset requestResidency];
+
+        return true;
+    }
+#else
+    GGML_UNUSED(ctx_dev);
+    GGML_UNUSED(device);
+#endif
+
+    return true;
+}
+
+// rset free
+static void ggml_backend_metal_buffer_rset_free(struct ggml_backend_metal_buffer_context * ctx) {
+#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
+    if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) {
+        if (ctx->rset) {
+            [ctx->rset endResidency];
+            [ctx->rset removeAllAllocations];
+            [ctx->rset release];
+        }
+    }
+#else
+    GGML_UNUSED(ctx);
+#endif
+}
+
 // finds the Metal buffer that contains the tensor data on the GPU device
 // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
 // Metal buffer based on the host memory pointer
@@ -1122,12 +1230,13 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_SUM_ROWS:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_GROUP_NORM:
-            return has_simdgroup_reduction;
+            return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
         case GGML_OP_RMS_NORM:
-            return has_simdgroup_reduction && (op->ne[0] % 4 == 0);
+            return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
         case GGML_OP_ARGMAX:
-        case GGML_OP_NORM:
             return true;
+        case GGML_OP_NORM:
+            return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
         case GGML_OP_ROPE:
             {
                 const int mode = ((const int32_t *) op->op_params)[2];
@@ -1201,6 +1310,18 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                             default:
                                 return false;
                         }
+                    case GGML_TYPE_Q4_0:
+                    case GGML_TYPE_Q4_1:
+                    case GGML_TYPE_Q5_0:
+                    case GGML_TYPE_Q5_1:
+                    case GGML_TYPE_Q8_0:
+                        switch (op->type) {
+                            case GGML_TYPE_F32:
+                            case GGML_TYPE_F16:
+                                return true;
+                            default:
+                                return false;
+                        }
                     default:
                         return false;
                 };
@@ -1897,7 +2018,7 @@ static void ggml_metal_encode_node(
                 const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
                 // TODO: add ggml_metal_kargs struct
-                // TODO: optimize (see https://github.com/ggerganov/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6)
+                // TODO: optimize (see https://github.com/ggml-org/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6)
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0   atIndex:0];
                 if (id_src1) {
@@ -3843,10 +3964,6 @@ static void ggml_metal_encode_node(
         case GGML_OP_CPY:
         case GGML_OP_CONT:
             {
-                GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
-
-                int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
-
                 id pipeline = nil;
 
                 switch (src0t) {
@@ -3880,7 +3997,47 @@ static void ggml_metal_encode_node(
                             switch (dstt) {
                                 case GGML_TYPE_F32:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break;
                                 case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break;
-                                default: GGML_ASSERT(false && "not implemented");
+                                default: GGML_ABORT("not implemented");
+                            };
+                        } break;
+                    case GGML_TYPE_Q4_0:
+                        {
+                            switch (dstt) {
+                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32].pipeline; break;
+                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            };
+                        } break;
+                    case GGML_TYPE_Q4_1:
+                        {
+                            switch (dstt) {
+                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32].pipeline; break;
+                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            };
+                        } break;
+                    case GGML_TYPE_Q5_0:
+                        {
+                            switch (dstt) {
+                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32].pipeline; break;
+                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            };
+                        } break;
+                    case GGML_TYPE_Q5_1:
+                        {
+                            switch (dstt) {
+                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32].pipeline; break;
+                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            };
+                        } break;
+                    case GGML_TYPE_Q8_0:
+                        {
+                            switch (dstt) {
+                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32].pipeline; break;
+                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16].pipeline; break;
+                                default: GGML_ABORT("not implemented");
                             };
                         } break;
                     default: GGML_ABORT("not implemented");
@@ -3910,7 +4067,11 @@ static void ggml_metal_encode_node(
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
 
+                GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
+                int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
+
                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+
             } break;
         case GGML_OP_SET:
             {
@@ -4209,6 +4370,8 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
     for (int i = 0; i < ctx->n_buffers; i++) {
         [ctx->buffers[i].metal release];
     }
+
+    ggml_backend_metal_buffer_rset_free(ctx);
     ggml_backend_metal_device_rel(buffer->buft->device->context);
 
     if (ctx->owned) {
@@ -4232,19 +4395,19 @@ static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
 static void ggml_backend_metal_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
     memset((char *)tensor->data + offset, value, size);
 
-    UNUSED(buffer);
+    GGML_UNUSED(buffer);
 }
 
 static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     memcpy((char *)tensor->data + offset, data, size);
 
-    UNUSED(buffer);
+    GGML_UNUSED(buffer);
 }
 
 static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
     memcpy(data, (const char *)tensor->data + offset, size);
 
-    UNUSED(buffer);
+    GGML_UNUSED(buffer);
 }
 
 static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
@@ -4254,7 +4417,7 @@ static bool ggml_backend_metal_buffer_cpy_tensor(ggml_backend_buffer_t buffer, c
     }
     return false;
 
-    UNUSED(buffer);
+    GGML_UNUSED(buffer);
 }
 
 static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
@@ -4280,7 +4443,7 @@ static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = {
 static const char * ggml_backend_metal_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
     return "Metal";
 
-    UNUSED(buft);
+    GGML_UNUSED(buft);
 }
 
 static void ggml_backend_metal_log_allocated_size(id device, size_t size_aligned) {
@@ -4304,8 +4467,8 @@ static void ggml_backend_metal_log_allocated_size(id device, size_t s
     }
 #endif
 #endif
-    UNUSED(device);
-    UNUSED(size_aligned);
+    GGML_UNUSED(device);
+    GGML_UNUSED(size_aligned);
 }
 
 static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
@@ -4318,7 +4481,8 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
         size_aligned += (size_page - (size_aligned % size_page));
     }
 
-    id device = ggml_backend_metal_device_acq(buft->device->context);
+    struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)buft->device->context;
+    id device = ggml_backend_metal_device_acq(ctx_dev);
 
     ctx->all_data = ggml_metal_host_malloc(size_aligned);
     ctx->all_size = size_aligned;
@@ -4341,7 +4505,14 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
     if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
         GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
         free(ctx);
-        ggml_backend_metal_device_rel(buft->device->context);
+        ggml_backend_metal_device_rel(ctx_dev);
+        return NULL;
+    }
+
+    if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
+        GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
+        free(ctx);
+        ggml_backend_metal_device_rel(ctx_dev);
         return NULL;
     }
 
@@ -4352,7 +4523,7 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
 
 static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
     return 32;
-    UNUSED(buft);
+    GGML_UNUSED(buft);
 }
 
 static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
@@ -4362,13 +4533,13 @@ static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_ty
 
     return max_size;
 
-    UNUSED(buft);
+    GGML_UNUSED(buft);
 }
 
 static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
     return true;
 
-    UNUSED(buft);
+    GGML_UNUSED(buft);
 }
 
 ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
@@ -4391,7 +4562,7 @@ ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
 static const char * ggml_backend_metal_buffer_from_ptr_type_get_name(ggml_backend_buffer_type_t buft) {
     return "Metal_Mapped";
 
-    UNUSED(buft);
+    GGML_UNUSED(buft);
 }
 
 static ggml_backend_buffer_type_t ggml_backend_metal_buffer_from_ptr_type(void) {
@@ -4434,7 +4605,8 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
         size_aligned += (size_page - (size_aligned % size_page));
     }
 
-    id device = ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);
+    struct ggml_backend_metal_device_context * ctx_dev = &g_ggml_ctx_dev_main;
+    id device = ggml_backend_metal_device_acq(ctx_dev);
 
     // the buffer fits into the max buffer size allowed by the device
     if (size_aligned <= device.maxBufferLength) {
@@ -4487,6 +4659,13 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
         }
     }
 
+    if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
+        GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
+        free(ctx);
+        ggml_backend_metal_device_rel(ctx_dev);
+        return NULL;
+    }
+
     return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size);
 }
 
@@ -4495,7 +4674,7 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
 static const char * ggml_backend_metal_name(ggml_backend_t backend) {
     return "Metal";
 
-    UNUSED(backend);
+    GGML_UNUSED(backend);
 }
 
 static void ggml_backend_metal_free(ggml_backend_t backend) {
@@ -4800,6 +4979,13 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_back
         }
     }
 
+    if (!ggml_backend_metal_buffer_rset_init(ctx, ctx_dev, device)) {
+        GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
+        free(ctx);
+        ggml_backend_metal_device_rel(ctx_dev);
+        return NULL;
+    }
+
     return ggml_backend_buffer_init(ggml_backend_metal_buffer_from_ptr_type(), ggml_backend_metal_buffer_i, ctx, size);
 }
 
@@ -4813,7 +4999,7 @@ static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml
     return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name ||
             buft->iface.get_name == ggml_backend_metal_buffer_from_ptr_type_get_name;
 
-    UNUSED(dev);
+    GGML_UNUSED(dev);
 }
 
 static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal
index 204c93e6..f38909d0 100644
--- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal
@@ -373,24 +373,33 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
 template 
 void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
     const half d_all = xb->d;
-    device const uint8_t * ql = (device const uint8_t *)xb->ql;
-    device const uint8_t * qh = (device const uint8_t *)xb->qh;
+    device const uint16_t * ql = (device const uint16_t *)xb->ql;
+    device const uint16_t * qh = (device const uint16_t *)xb->qh;
     device const int8_t * scales = (device const int8_t *)xb->scales;
 
-    ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
-    qh = qh + 32*(il/8) + 16*(il&1);
+    ql = ql + 32*(il/8) + 16*((il/2)&1) + 8*(il&1);
+    qh = qh + 16*(il/8) + 8*(il&1);
     float sc = scales[(il%2) + 2 * ((il/2))];
     il = (il/2) & 3;
 
-    const uint16_t  kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
-    const uint16_t  kmask2 = il>1 ? 0xF0              : 0x0F;
-    const float       coef = il>1 ? 1.f/16.f          : 1.f;
+    const uint32_t kmask1 = il>1 ? (il>2 ? 0xC0C0C0C0 : 0x30303030) : (il>0 ? 0x0C0C0C0C : 0x03030303);
+    const uint32_t kmask2 = il>1 ? 0xF0F0F0F0                       : 0x0F0F0F0F;
     const float ml = d_all * sc * 32.f;
-    const float dl = d_all * sc * coef;
-    for (int i = 0; i < 16; ++i) {
-        const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
-                            : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
-        reg[i/4][i%4] = dl * q - ml;
+    const float dl0 = d_all * sc;
+    const float dl1 = dl0 / 256.f;
+    const float dl2 = dl0 / (256.f * 256.f);
+    const float dl3 = dl0 / (256.f * 256.f * 256.f);
+    const uint8_t shr_h = il>2 ? 2 : 0;
+    const uint8_t shl_h = il>1 ? 0 : (il>0 ? 2 : 4);
+    const uint8_t shr_l = il>1 ? 4 : 0;
+    for (int i = 0; i < 4; ++i) {
+        const uint32_t  low = (ql[2*i] | (uint32_t)(ql[2*i+1] << 16)) & kmask2;
+        const uint32_t high = (qh[2*i] | (uint32_t)(qh[2*i+1] << 16)) & kmask1;
+        const uint32_t q = ((high << shl_h) >> shr_h) | (low >> shr_l);
+        reg[i][0] = dl0 *  ((half)(q & 0xFF))       - ml;
+        reg[i][1] = dl1 * ((float)(q & 0xFF00))     - ml;
+        reg[i][2] = dl2 * ((float)(q & 0xFF0000))   - ml;
+        reg[i][3] = dl3 * ((float)(q & 0xFF000000)) - ml;
     }
 }
 
@@ -1058,7 +1067,7 @@ kernel void kernel_soft_max(
     }
 
     // This barrier fixes a failing test
-    // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+    // ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335
     threadgroup_barrier(mem_flags::mem_none);
 
     float sum = simd_sum(lsum);
@@ -1163,7 +1172,7 @@ kernel void kernel_soft_max_4(
     const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
 
     // This barrier fixes a failing test
-    // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+    // ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335
     threadgroup_barrier(mem_flags::mem_none);
 
     float sum = simd_sum(lsum);
@@ -4377,6 +4386,49 @@ kernel void kernel_cpy_f32_iq4_nl(
     }
 }
 
+template
+kernel void kernel_cpy_q_f32(
+        constant ggml_metal_kargs_cpy & args,
+        device  const char * src0,
+        device        char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig[2];
+    const int i02 = tgpig[1];
+    const int i01 = tgpig[0];
+
+    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
+
+    const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
+    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
+    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
+    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
+
+    device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
+    device       T4x4    * dst_data = (device       T4x4    *)(dst  +  i3*args.nb3  +  i2*args.nb2  +  i1*args.nb1 + i0*args.nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < args.ne00/16; i00 += ntg.x) {
+        T4x4 temp;
+        dequantize_func(src_data + i00/nl, i00%nl, temp);
+        dst_data[i00] = temp;
+    }
+}
+
+typedef decltype(kernel_cpy_q_f32) cpy_q_f_t;
+
+template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+
+template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+template [[host_name("kernel_cpy_q5_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+template [[host_name("kernel_cpy_q8_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32;
+
 kernel void kernel_concat(
     constant ggml_metal_kargs_concat & args,
     device  const char * src0,
@@ -4461,7 +4513,6 @@ void kernel_mul_mv_q2_K_f32_impl(
         device const half     * dh = &x[ib].d;
 
         for (int row = 0; row < N_DST; row++) {
-
             float4 acc1 = {0.f, 0.f, 0.f, 0.f};
             float4 acc2 = {0.f, 0.f, 0.f, 0.f};
             for (int i = 0; i < 8; i += 2) {
@@ -4492,7 +4543,7 @@ void kernel_mul_mv_q2_K_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum;
@@ -4658,7 +4709,7 @@ void kernel_mul_mv_q3_K_f32_impl(
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
     if (tiisg == 0) {
-        for (int row = 0; row < 2; ++row) {
+        for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
             dst_f32[first_row + row] = sumf1[row];
         }
     }
@@ -4774,7 +4825,7 @@ void kernel_mul_mv_q4_K_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum;
@@ -4906,7 +4957,7 @@ void kernel_mul_mv_q5_K_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < 2; ++row) {
+    for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
         const float tot = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = tot;
@@ -4951,6 +5002,10 @@ void kernel_mul_mv_q6_K_f32_impl(
 
     const int row = 2*r0 + sgitg;
 
+    if (row >= args.ne0) {
+        return;
+    }
+
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
 
@@ -5106,7 +5161,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum * 0.25f;
@@ -5224,7 +5279,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum * 0.25f;
@@ -5334,7 +5389,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum * 0.5f;
@@ -5446,7 +5501,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum;
@@ -5559,7 +5614,7 @@ void kernel_mul_mv_iq2_s_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum * 0.25f;
@@ -5659,7 +5714,7 @@ void kernel_mul_mv_iq1_s_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum;
@@ -5754,7 +5809,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum;
@@ -5844,7 +5899,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) {
+    for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum;
@@ -5933,7 +5988,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < 2; ++row) {
+    for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum;
diff --git a/ml/backend/ggml/ggml/src/ggml.c b/ml/backend/ggml/ggml/src/ggml.c
index 7ffcd907..635aa299 100644
--- a/ml/backend/ggml/ggml/src/ggml.c
+++ b/ml/backend/ggml/ggml/src/ggml.c
@@ -128,6 +128,10 @@ static void ggml_print_backtrace_symbols(void) {
 #endif
 
 static void ggml_print_backtrace(void) {
+    const char * GGML_NO_BACKTRACE = getenv("GGML_NO_BACKTRACE");
+    if (GGML_NO_BACKTRACE) {
+        return;
+    }
     char attach[32];
     snprintf(attach, sizeof(attach), "attach %d", getpid());
     int pid = fork();
@@ -236,7 +240,11 @@ void ggml_log_callback_default(enum ggml_log_level level, const char * text, voi
 
 
 void * ggml_aligned_malloc(size_t size) {
+#if defined(__s390x__)
+    const int alignment = 256;
+#else
     const int alignment = 64;
+#endif
 
 #if defined(_MSC_VER) || defined(__MINGW32__)
     return _aligned_malloc(size, alignment);
@@ -969,6 +977,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GET_REL_POS",
     "ADD_REL_POS",
     "RWKV_WKV6",
+    "GATED_LINEAR_ATTN",
 
     "UNARY",
 
@@ -988,7 +997,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "OPT_STEP_ADAMW",
 };
 
-static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
+static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1066,6 +1075,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "get_rel_pos(x)",
     "add_rel_pos(x)",
     "rwkv_wkv6(k, v, r, tf, td, s)",
+    "gated_linear_attn(k, v, q, gate, s)",
 
     "unary(x)",
 
@@ -1085,7 +1095,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "adamw(x)",
 };
 
-static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
+static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -1375,7 +1385,7 @@ bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tenso
         (t0->nb[3] == t1->nb[3]);
 }
 
-// check if t1 can be represented as a repeatition of t0
+// check if t1 can be represented as a repetition of t0
 bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
@@ -1590,15 +1600,8 @@ static struct ggml_tensor * ggml_new_tensor_impl(
 
     struct ggml_tensor * const result = (struct ggml_tensor *)((char *)ctx->mem_buffer + obj_new->offs);
 
-#ifdef __clang__
-    // temporary until ggml_tensor::backend is removed
-    #pragma clang diagnostic push
-    #pragma clang diagnostic ignored "-Wdeprecated-declarations"
-#endif
-
     *result = (struct ggml_tensor) {
         /*.type         =*/ type,
-        /*.backend      =*/ GGML_BACKEND_TYPE_CPU,
         /*.buffer       =*/ NULL,
         /*.ne           =*/ { 1, 1, 1, 1 },
         /*.nb           =*/ { 0, 0, 0, 0 },
@@ -1614,10 +1617,6 @@ static struct ggml_tensor * ggml_new_tensor_impl(
         /*.padding      =*/ { 0 },
     };
 
-#ifdef __clang__
-    #pragma clang diagnostic pop
-#endif
-
     // TODO: this should not be needed as long as we don't rely on aligned SIMD loads
     //GGML_ASSERT_ALIGNED(result->data);
 
@@ -3461,12 +3460,14 @@ struct ggml_tensor * ggml_soft_max_ext(
     return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
 }
 
-// ggml_soft_max_back
+// ggml_soft_max_ext_back
 
-static struct ggml_tensor * ggml_soft_max_back_impl(
+static struct ggml_tensor * ggml_soft_max_ext_back_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
+        float                 scale,
+        float                 max_bias,
         bool                  inplace) {
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
@@ -3474,21 +3475,28 @@ static struct ggml_tensor * ggml_soft_max_back_impl(
     result->src[0] = a;
     result->src[1] = b;
 
+    memcpy((float *) result->op_params + 0, &scale,    sizeof(float));
+    memcpy((float *) result->op_params + 1, &max_bias, sizeof(float));
+
     return result;
 }
 
-struct ggml_tensor * ggml_soft_max_back(
+struct ggml_tensor * ggml_soft_max_ext_back(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        struct ggml_tensor  * b) {
-    return ggml_soft_max_back_impl(ctx, a, b, false);
+        struct ggml_tensor  * b,
+        float                 scale,
+        float                 max_bias) {
+    return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, false);
 }
 
-struct ggml_tensor * ggml_soft_max_back_inplace(
+struct ggml_tensor * ggml_soft_max_ext_back_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
-        struct ggml_tensor  * b) {
-    return ggml_soft_max_back_impl(ctx, a, b, true);
+        struct ggml_tensor  * b,
+        float                 scale,
+        float                 max_bias) {
+    return ggml_soft_max_ext_back_impl(ctx, a, b, scale, max_bias, true);
 }
 
 // ggml_rope
@@ -3706,7 +3714,7 @@ void ggml_rope_yarn_corr_dims(
 
 // ggml_rope_back
 
-struct ggml_tensor * ggml_rope_back(
+struct ggml_tensor * ggml_rope_ext_back(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
@@ -3720,29 +3728,32 @@ struct ggml_tensor * ggml_rope_back(
         float                 attn_factor,
         float                 beta_fast,
         float                 beta_slow) {
-    GGML_ASSERT(ggml_is_vector(b));
-    GGML_ASSERT(b->type == GGML_TYPE_I32);
-    GGML_ASSERT(a->ne[2] == b->ne[0]);
-
-    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
-
-    int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
-    memcpy(params +  5, &freq_base,    sizeof(float));
-    memcpy(params +  6, &freq_scale,   sizeof(float));
-    memcpy(params +  7, &ext_factor,   sizeof(float));
-    memcpy(params +  8, &attn_factor,  sizeof(float));
-    memcpy(params +  9, &beta_fast,    sizeof(float));
-    memcpy(params + 10, &beta_slow,    sizeof(float));
-    ggml_set_op_params(result, params, sizeof(params));
-
-    result->op     = GGML_OP_ROPE_BACK;
-    result->src[0] = a;
-    result->src[1] = b;
-    result->src[2] = c;
-
+    struct ggml_tensor * result = ggml_rope_ext(
+        ctx, a, b, c, n_dims, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+    result->op = GGML_OP_ROPE_BACK;
     return result;
 }
 
+struct ggml_tensor * ggml_rope_multi_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        struct ggml_tensor  * c,
+        int                   n_dims,
+        int                   sections[4],
+        int                   mode,
+        int                   n_ctx_orig,
+        float                 freq_base,
+        float                 freq_scale,
+        float                 ext_factor,
+        float                 attn_factor,
+        float                 beta_fast,
+        float                 beta_slow) {
+    struct ggml_tensor * result = ggml_rope_multi(
+        ctx, a, b, c, n_dims, sections, mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+    result->op = GGML_OP_ROPE_BACK;
+    return result;
+}
 // ggml_clamp
 
 struct ggml_tensor * ggml_clamp(
@@ -4661,15 +4672,13 @@ struct ggml_tensor * ggml_rwkv_wkv6(
     GGML_ASSERT(ggml_is_contiguous(state));
 
     const int64_t S = k->ne[0];
-    const int64_t H = k->ne[2];
-    const int64_t n_tokens = k->ne[3];
+    const int64_t H = k->ne[1];
+    const int64_t n_tokens = k->ne[2];
     const int64_t n_seqs = state->ne[1];
     {
-        GGML_ASSERT(k->ne[1] == 1);
-        GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens);
-        GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens);
-        // TODO: RWKV v4 and v5
-        GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens);
+        GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
+        GGML_ASSERT(r->ne[0] == S && r->ne[1] == H && r->ne[2] == n_tokens);
+        GGML_ASSERT(td->ne[0] == S && td->ne[1] == H && td->ne[2] == n_tokens);
         GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
     }
 
@@ -4688,6 +4697,49 @@ struct ggml_tensor * ggml_rwkv_wkv6(
     return result;
 }
 
+// ggml_gated_linear_attn
+
+struct ggml_tensor * ggml_gated_linear_attn(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * g,
+        struct ggml_tensor  * state,
+        float scale) {
+    GGML_ASSERT(ggml_is_contiguous(k));
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(q));
+    GGML_ASSERT(ggml_is_contiguous(g));
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    const int64_t S = k->ne[0];
+    const int64_t H = k->ne[1];
+    const int64_t n_tokens = k->ne[2];
+    const int64_t n_seqs = state->ne[1];
+    {
+        GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
+        GGML_ASSERT(q->ne[0] == S && q->ne[1] == H && q->ne[2] == n_tokens);
+        GGML_ASSERT(g->ne[0] == S && g->ne[1] == H && g->ne[2] == n_tokens);
+        GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
+    }
+
+    // concat output and new_state
+    const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    ggml_set_op_params_f32(result, 0, scale);
+
+    result->op     = GGML_OP_GATED_LINEAR_ATTN;
+    result->src[0] = k;
+    result->src[1] = v;
+    result->src[2] = q;
+    result->src[3] = g;
+    result->src[4] = state;
+
+    return result;
+}
+
 // ggml_unary
 
 static struct ggml_tensor * ggml_unary_impl(
@@ -5062,10 +5114,10 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
         struct ggml_tensor  * a,
         struct ggml_tensor  * b,
         struct ggml_tensor  * c) {
-    GGML_ASSERT(ggml_are_same_shape(a, b));
-    GGML_ASSERT(ggml_is_scalar(c));
+    GGML_ASSERT(ggml_is_scalar(a));
+    GGML_ASSERT(ggml_are_same_shape(b, c));
 
-    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, b);
 
     result->op     = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
     result->src[0] = a;
@@ -5244,7 +5296,7 @@ static void ggml_sub_or_set(
 }
 
 static void ggml_compute_backward(
-        struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, bool * grads_needed) {
+        struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, const bool * grads_needed) {
     struct ggml_tensor * tensor = cgraph->nodes[i];
     struct ggml_tensor * grad   = ggml_graph_get_grad(cgraph, tensor);
 
@@ -5316,7 +5368,7 @@ static void ggml_compute_backward(
         } break;
         case GGML_OP_MUL: {
             if (src0_needs_grads) {
-                ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, src1, grad));
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, src1));
             }
             if (src1_needs_grads) {
                 struct ggml_tensor * tmp = ggml_mul(ctx, src0, grad);
@@ -5388,7 +5440,7 @@ static void ggml_compute_backward(
             if (src0_needs_grads) {
                 float eps;
                 memcpy(&eps, tensor->op_params, sizeof(float));
-                ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, src0, grad, eps));
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, grad, src0, eps));
             }
         } break;
         case GGML_OP_MUL_MAT: {
@@ -5408,21 +5460,25 @@ static void ggml_compute_backward(
             // src1.shape   [n,p,qq,rr]
 
             if (src0_needs_grads) {
-                struct ggml_tensor * s1_tg =
+                GGML_ASSERT(grad->ne[2] == src1->ne[2]);
+                GGML_ASSERT(grad->ne[3] == src1->ne[3]);
+                struct ggml_tensor * tmp =
                     ggml_out_prod(ctx, // [n,m,qq,rr]
                         src1,          // [n,p,qq,rr]
                         grad);         // [m,p,qq,rr]
-                const int64_t qq = s1_tg->ne[2];
-                const int64_t rr = s1_tg->ne[3];
-                const int64_t q1 = src0->ne[2];
-                const int64_t r1 = src0->ne[3];
-                const bool ne2_broadcasted = qq > q1;
-                const bool ne3_broadcasted = rr > r1;
-                if (ne2_broadcasted || ne3_broadcasted) {
-                    // sum broadcast repetitions of s1_tg into shape of src0
-                    s1_tg = ggml_repeat_back(ctx, s1_tg, src0);
+                if (!ggml_are_same_shape(tmp, src0)) {
+                    GGML_ASSERT(tmp->ne[0] == src0->ne[0]);
+                    GGML_ASSERT(tmp->ne[1] == src0->ne[1]);
+                    GGML_ASSERT(tmp->ne[3] == 1);
+
+                    const int64_t nr2 = tmp->ne[2] / src0->ne[2];
+                    const size_t nb2 = tmp->nb[2] * nr2;
+                    const size_t nb3 = tmp->nb[2];
+
+                    tmp = ggml_view_4d(ctx, tmp, src0->ne[0], src0->ne[1], src0->ne[2], nr2, tmp->nb[1], nb2, nb3, 0);
+                    tmp = ggml_repeat_back(ctx, tmp, src0);
                 }
-                ggml_add_or_set(ctx, cgraph, isrc0, s1_tg /*= [n,m,q1,r1]*/);
+                ggml_add_or_set(ctx, cgraph, isrc0, tmp);
             }
             if (src1_needs_grads) {
                 ggml_add_or_set(ctx, cgraph, isrc1,
@@ -5491,7 +5547,9 @@ static void ggml_compute_backward(
             if (src0_needs_grads) {
                 GGML_ASSERT(!cgraph->grads[isrc0] || ggml_is_contiguous(cgraph->grads[isrc0]));
                 GGML_ASSERT(ggml_is_contiguous(grad));
-                ggml_add_or_set(ctx, cgraph, isrc0, grad);
+                GGML_ASSERT(ggml_nelements(tensor) == ggml_nelements(src0));
+                ggml_add_or_set(ctx, cgraph, isrc0,
+                    ggml_are_same_shape(tensor, src0) ? grad : ggml_reshape(ctx, grad, src0));
             }
         } break;
         case GGML_OP_RESHAPE: {
@@ -5571,7 +5629,13 @@ static void ggml_compute_backward(
         } break;
         case GGML_OP_SOFT_MAX: {
             if (src0_needs_grads) {
-                ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_back(ctx, grad, tensor));
+                float scale    = 1.0f;
+                float max_bias = 0.0f;
+
+                memcpy(&scale,    (const float *) tensor->op_params + 0, sizeof(float));
+                memcpy(&max_bias, (const float *) tensor->op_params + 1, sizeof(float));
+
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_ext_back(ctx, grad, tensor, scale, max_bias));
             }
             GGML_ASSERT((!src1 || !src1_needs_grads) && "backward pass for softmax mask not implemented");
         } break;
@@ -5583,6 +5647,7 @@ static void ggml_compute_backward(
                 //const int n_ctx      = ((int32_t *) tensor->op_params)[3];
                 const int n_ctx_orig = ((const int32_t *) tensor->op_params)[4];
                 float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+                int sections[4] = {0, 0, 0, 0};
 
                 memcpy(&freq_base,   (const float *) tensor->op_params +  5, sizeof(float));
                 memcpy(&freq_scale,  (const float *) tensor->op_params +  6, sizeof(float));
@@ -5590,10 +5655,14 @@ static void ggml_compute_backward(
                 memcpy(&attn_factor, (const float *) tensor->op_params +  8, sizeof(float));
                 memcpy(&beta_fast,   (const float *) tensor->op_params +  9, sizeof(float));
                 memcpy(&beta_slow,   (const float *) tensor->op_params + 10, sizeof(float));
+                memcpy(§ions,                    tensor->op_params + 11, sizeof(sections));
 
-                ggml_add_or_set(ctx, cgraph, isrc0,
-                    ggml_rope_back(ctx, grad, src1, src2, n_dims, mode, n_ctx_orig, freq_base,
-                        freq_scale, ext_factor, attn_factor, beta_fast, beta_slow));
+                struct ggml_tensor * rope_back = grad->ne[2] == src1->ne[0] ?
+                    ggml_rope_ext_back(ctx, grad, src1, src2, n_dims,
+                        mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow) :
+                    ggml_rope_multi_back(ctx, grad, src1, src2, n_dims, sections,
+                        mode, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+                ggml_add_or_set(ctx, cgraph, isrc0, rope_back);
             }
             GGML_ASSERT((!src2 || !src2_needs_grads) && "gradients for freq factors not implemented");
         } break;
@@ -5607,7 +5676,7 @@ static void ggml_compute_backward(
                 const int32_t d1    = ggml_get_op_params_i32(tensor, 5);
                 const bool    is_2D = ggml_get_op_params_i32(tensor, 6) == 1;
 
-                ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, src0, grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D));
+                ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, grad, src0, src1->ne, s0, s1, p0, p1, d0, d1, is_2D));
             }
         } break;
         case GGML_OP_POOL_2D: {
@@ -5650,7 +5719,7 @@ static void ggml_compute_backward(
                 } break;
                 case GGML_UNARY_OP_SILU: {
                     if (src0_needs_grads) {
-                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, src0, grad));
+                        ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, grad, src0));
                     }
                 } break;
                 case GGML_UNARY_OP_EXP: {
@@ -5667,7 +5736,7 @@ static void ggml_compute_backward(
         } break;
         case GGML_OP_CROSS_ENTROPY_LOSS: {
             if (src0_needs_grads) {
-                ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, src0, src1, grad));
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, grad, src0, src1));
             }
             GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented");
         } break;
@@ -6438,1271 +6507,6 @@ size_t ggml_quantize_chunk(
 
 ////////////////////////////////////////////////////////////////////////////////
 
-struct gguf_str {
-    uint64_t n;  // GGUFv2
-    char * data;
-};
-
-static const size_t GGUF_TYPE_SIZE[GGUF_TYPE_COUNT] = {
-    [GGUF_TYPE_UINT8]   = sizeof(uint8_t),
-    [GGUF_TYPE_INT8]    = sizeof(int8_t),
-    [GGUF_TYPE_UINT16]  = sizeof(uint16_t),
-    [GGUF_TYPE_INT16]   = sizeof(int16_t),
-    [GGUF_TYPE_UINT32]  = sizeof(uint32_t),
-    [GGUF_TYPE_INT32]   = sizeof(int32_t),
-    [GGUF_TYPE_FLOAT32] = sizeof(float),
-    [GGUF_TYPE_BOOL]    = sizeof(bool),
-    [GGUF_TYPE_STRING]  = sizeof(struct gguf_str),
-    [GGUF_TYPE_UINT64]  = sizeof(uint64_t),
-    [GGUF_TYPE_INT64]   = sizeof(int64_t),
-    [GGUF_TYPE_FLOAT64] = sizeof(double),
-    [GGUF_TYPE_ARRAY]   = 0, // undefined
-};
-static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13");
-
-static const char * GGUF_TYPE_NAME[GGUF_TYPE_COUNT] = {
-    [GGUF_TYPE_UINT8]   = "u8",
-    [GGUF_TYPE_INT8]    = "i8",
-    [GGUF_TYPE_UINT16]  = "u16",
-    [GGUF_TYPE_INT16]   = "i16",
-    [GGUF_TYPE_UINT32]  = "u32",
-    [GGUF_TYPE_INT32]   = "i32",
-    [GGUF_TYPE_FLOAT32] = "f32",
-    [GGUF_TYPE_BOOL]    = "bool",
-    [GGUF_TYPE_STRING]  = "str",
-    [GGUF_TYPE_ARRAY]   = "arr",
-    [GGUF_TYPE_UINT64]  = "u64",
-    [GGUF_TYPE_INT64]   = "i64",
-    [GGUF_TYPE_FLOAT64] = "f64",
-};
-static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13");
-
-union gguf_value {
-    uint8_t  uint8;
-    int8_t   int8;
-    uint16_t uint16;
-    int16_t  int16;
-    uint32_t uint32;
-    int32_t  int32;
-    float    float32;
-    uint64_t uint64;
-    int64_t  int64;
-    double   float64;
-    bool     bool_;
-
-    struct gguf_str str;
-
-    struct {
-        enum gguf_type type;
-
-        uint64_t n;  // GGUFv2
-        void * data;
-    } arr;
-};
-
-struct gguf_kv {
-    struct gguf_str key;
-
-    enum  gguf_type  type;
-    union gguf_value value;
-};
-
-struct gguf_header {
-    char magic[4];
-
-    uint32_t version;
-    uint64_t n_tensors; // GGUFv2
-    uint64_t n_kv;      // GGUFv2
-};
-
-struct gguf_tensor_info {
-    struct gguf_str name;
-
-    uint32_t n_dims;
-    uint64_t ne[GGML_MAX_DIMS];
-
-    enum ggml_type type;
-
-    uint64_t offset; // offset from start of `data`, must be a multiple of `ALIGNMENT`
-
-    // for writing API
-    const void * data;
-    size_t size;
-};
-
-struct gguf_context {
-    struct gguf_header header;
-
-    struct gguf_kv          * kv;
-    struct gguf_tensor_info * infos;
-
-    size_t alignment;
-    size_t offset;    // offset of `data` from beginning of file
-    size_t size;      // size of `data` in bytes
-
-    //uint8_t * padding;
-    void * data;
-};
-
-size_t gguf_type_size(enum gguf_type type) {
-    GGML_ASSERT(0 <= type && type < GGUF_TYPE_COUNT);
-    return GGUF_TYPE_SIZE[type];
-}
-
-static bool gguf_tensor_info_sanitize(struct gguf_tensor_info * info) {
-    if (info->n_dims > GGML_MAX_DIMS) {
-        fprintf(stderr, "%s: invalid number of dimensions (%" PRIu32 ")\n", __func__, info->n_dims);
-        return false;
-    }
-
-    if (info->type < 0 || info->type >= GGML_TYPE_COUNT) {
-        fprintf(stderr, "%s: invalid type (%d)\n", __func__, info->type);
-        return false;
-    }
-
-    if (strlen(info->name.data) >= GGML_MAX_NAME) {
-        fprintf(stderr, "%s: tensor '%s' name is too long\n", __func__, info->name.data);
-        return false;
-    }
-
-    for (uint32_t i = 0; i < info->n_dims; ++i) {
-        if (info->ne[i] <= 0) {
-            fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[i]);
-            return false;
-        }
-    }
-
-    // prevent overflow for total number of elements
-    if (INT64_MAX/info->ne[1] <= info->ne[0]) {
-        fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[1]);
-        return false;
-    }
-
-    if (INT64_MAX/info->ne[2] <= info->ne[0]*info->ne[1]) {
-        fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[2]);
-        return false;
-    }
-
-    if (INT64_MAX/info->ne[3] <= info->ne[0]*info->ne[1]*info->ne[2]) {
-        fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[3]);
-        return false;
-    }
-
-    return true;
-}
-
-static bool gguf_fread_el(FILE * file, void * dst, size_t size, size_t * offset) {
-    const size_t n = fread(dst, 1, size, file);
-    *offset += n;
-    return n == size;
-}
-
-static bool gguf_fread_str(FILE * file, struct gguf_str * p, size_t * offset) {
-    p->n    = 0;
-    p->data = NULL;
-
-    bool ok = true;
-
-    ok = ok && gguf_fread_el(file, &p->n, sizeof(p->n), offset);
-
-    // early exit if string length is invalid, prevents from integer overflow
-    if (p->n == SIZE_MAX) {
-        fprintf(stderr, "%s: invalid string length (%" PRIu64 ")\n", __func__, p->n);
-        return false;
-    }
-
-    p->data = calloc(p->n + 1, 1);
-    if (!p->data) {
-        fprintf(stderr, "%s: failed to allocate memory for string of length %" PRIu64 "\n", __func__, p->n);
-        return false;
-    }
-
-    ok = ok && gguf_fread_el(file,  p->data, p->n, offset);
-
-    return ok;
-}
-
-static void gguf_free_kv(struct gguf_kv * kv) {
-    if (kv->key.data) {
-        GGML_FREE(kv->key.data);
-    }
-
-    if (kv->type == GGUF_TYPE_STRING) {
-        if (kv->value.str.data) {
-            GGML_FREE(kv->value.str.data);
-        }
-    }
-
-    if (kv->type == GGUF_TYPE_ARRAY) {
-        if (kv->value.arr.data) {
-            if (kv->value.arr.type == GGUF_TYPE_STRING) {
-                for (uint64_t j = 0; j < kv->value.arr.n; ++j) {
-                    struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[j];
-                    if (str->data) {
-                        GGML_FREE(str->data);
-                    }
-                }
-            }
-            GGML_FREE(kv->value.arr.data);
-        }
-    }
-}
-
-struct gguf_context * gguf_init_empty(void) {
-    struct gguf_context * ctx = calloc(1, sizeof(struct gguf_context));
-    if (!ctx) {
-        fprintf(stderr, "%s: failed to allocate memory for context\n", __func__);
-        return NULL;
-    }
-
-    memcpy(ctx->header.magic, GGUF_MAGIC, sizeof(ctx->header.magic));
-    ctx->header.version   = GGUF_VERSION;
-    ctx->header.n_tensors = 0;
-    ctx->header.n_kv      = 0;
-
-    ctx->kv    = NULL;
-    ctx->infos = NULL;
-
-    ctx->alignment = GGUF_DEFAULT_ALIGNMENT;
-    ctx->offset    = 0;
-    ctx->size      = 0;
-
-    ctx->data = NULL;
-
-    return ctx;
-}
-
-struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params) {
-    // offset from start of file
-    size_t offset = 0;
-
-    char magic[4];
-
-    // check the magic before making allocations
-    {
-        gguf_fread_el(file, &magic, sizeof(magic), &offset);
-
-        for (uint32_t i = 0; i < sizeof(magic); i++) {
-            if (magic[i] != GGUF_MAGIC[i]) {
-                fprintf(stderr, "%s: invalid magic characters '%c%c%c%c'\n", __func__, magic[0], magic[1], magic[2], magic[3]);
-                return NULL;
-            }
-        }
-    }
-
-    bool ok = true;
-
-    struct gguf_context * ctx = calloc(1, sizeof(struct gguf_context));
-    if (!ctx) {
-        fprintf(stderr, "%s: failed to allocate memory for context\n", __func__);
-        return NULL;
-    }
-
-    // read the header
-    {
-        strncpy(ctx->header.magic, magic, 4);
-
-        ctx->kv    = NULL;
-        ctx->infos = NULL;
-        ctx->data  = NULL;
-
-        ok = ok && gguf_fread_el(file, &ctx->header.version,   sizeof(ctx->header.version),   &offset);
-        ok = ok && gguf_fread_el(file, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors), &offset);
-        ok = ok && gguf_fread_el(file, &ctx->header.n_kv,      sizeof(ctx->header.n_kv),      &offset);
-
-        if (ctx->header.version == 1) {
-            fprintf(stderr, "%s: GGUFv1 is no longer supported. please use a more up-to-date version\n", __func__);
-            gguf_free(ctx);
-            return NULL;
-        }
-
-        // sanity-checks to prevent from integer/buffer overflows
-
-        ok = ok && (ctx->header.n_tensors < (SIZE_MAX/2)/sizeof(struct gguf_tensor_info));
-        ok = ok && (ctx->header.n_tensors < (SIZE_MAX/2)/ggml_tensor_overhead());
-        ok = ok && (ctx->header.n_kv      < (SIZE_MAX/2)/sizeof(struct gguf_kv));
-
-        if (!ok) {
-            fprintf(stderr, "%s: failed to read header\n", __func__);
-            gguf_free(ctx);
-            return NULL;
-        }
-    }
-
-    // read the kv pairs
-    {
-        const uint64_t n_kv = ctx->header.n_kv;
-
-        if (n_kv > 0) {
-            ctx->kv = calloc(n_kv, sizeof(struct gguf_kv));
-            if (!ctx->kv) {
-                fprintf(stderr, "%s: failed to allocate memory for kv pairs\n", __func__);
-                gguf_free(ctx);
-                return NULL;
-            }
-        }
-
-        for (uint64_t i = 0; i < n_kv; ++i) {
-            struct gguf_kv * kv = &ctx->kv[i];
-
-            //fprintf(stderr, "%s: reading kv %d\n", __func__, i);
-
-            ok = ok && gguf_fread_str(file, &kv->key,                    &offset);
-            ok = ok && gguf_fread_el (file, &kv->type, sizeof(kv->type), &offset);
-
-            //fprintf(stderr, "%s: reading kv with key %s\n", __func__, kv->key.data);
-
-            switch (kv->type) {
-                case GGUF_TYPE_UINT8:   ok = ok && gguf_fread_el (file, &kv->value.uint8,   sizeof(kv->value.uint8),   &offset); break;
-                case GGUF_TYPE_INT8:    ok = ok && gguf_fread_el (file, &kv->value.int8,    sizeof(kv->value.int8),    &offset); break;
-                case GGUF_TYPE_UINT16:  ok = ok && gguf_fread_el (file, &kv->value.uint16,  sizeof(kv->value.uint16),  &offset); break;
-                case GGUF_TYPE_INT16:   ok = ok && gguf_fread_el (file, &kv->value.int16,   sizeof(kv->value.int16),   &offset); break;
-                case GGUF_TYPE_UINT32:  ok = ok && gguf_fread_el (file, &kv->value.uint32,  sizeof(kv->value.uint32),  &offset); break;
-                case GGUF_TYPE_INT32:   ok = ok && gguf_fread_el (file, &kv->value.int32,   sizeof(kv->value.int32),   &offset); break;
-                case GGUF_TYPE_FLOAT32: ok = ok && gguf_fread_el (file, &kv->value.float32, sizeof(kv->value.float32), &offset); break;
-                case GGUF_TYPE_UINT64:  ok = ok && gguf_fread_el (file, &kv->value.uint64,  sizeof(kv->value.uint64),  &offset); break;
-                case GGUF_TYPE_INT64:   ok = ok && gguf_fread_el (file, &kv->value.int64,   sizeof(kv->value.int64),   &offset); break;
-                case GGUF_TYPE_FLOAT64: ok = ok && gguf_fread_el (file, &kv->value.float64, sizeof(kv->value.float64), &offset); break;
-                case GGUF_TYPE_BOOL:    ok = ok && gguf_fread_el (file, &kv->value.bool_,   sizeof(kv->value.bool_),   &offset); break;
-                case GGUF_TYPE_STRING:  ok = ok && gguf_fread_str(file, &kv->value.str,                                &offset); break;
-                case GGUF_TYPE_ARRAY:
-                    {
-                        ok = ok && gguf_fread_el(file, &kv->value.arr.type, sizeof(kv->value.arr.type), &offset);
-                        ok = ok && gguf_fread_el(file, &kv->value.arr.n,    sizeof(kv->value.arr.n),    &offset);
-
-                        switch (kv->value.arr.type) {
-                            case GGUF_TYPE_UINT8:
-                            case GGUF_TYPE_INT8:
-                            case GGUF_TYPE_UINT16:
-                            case GGUF_TYPE_INT16:
-                            case GGUF_TYPE_UINT32:
-                            case GGUF_TYPE_INT32:
-                            case GGUF_TYPE_FLOAT32:
-                            case GGUF_TYPE_UINT64:
-                            case GGUF_TYPE_INT64:
-                            case GGUF_TYPE_FLOAT64:
-                            case GGUF_TYPE_BOOL:
-                                {
-                                    // prevent from integer overflow in the malloc below
-                                    if (kv->value.arr.n >= SIZE_MAX/gguf_type_size(kv->value.arr.type)) {
-                                        fprintf(stderr, "%s: array size is too large (%" PRIu64 ")\n", __func__, kv->value.arr.n);
-                                        gguf_free(ctx);
-                                        return NULL;
-                                    }
-
-                                    kv->value.arr.data = calloc(kv->value.arr.n, gguf_type_size(kv->value.arr.type));
-                                    if (!kv->value.arr.data) {
-                                        fprintf(stderr, "%s: failed to allocate memory for array\n", __func__);
-                                        gguf_free(ctx);
-                                        return NULL;
-                                    }
-
-                                    ok = ok && gguf_fread_el(file, kv->value.arr.data, kv->value.arr.n * gguf_type_size(kv->value.arr.type), &offset);
-                                } break;
-                            case GGUF_TYPE_STRING:
-                                {
-                                    // prevent from integer overflow in the malloc below
-                                    if (kv->value.arr.n >= SIZE_MAX/sizeof(struct gguf_str)) {
-                                        fprintf(stderr, "%s: array size is too large (%" PRIu64 ")\n", __func__, kv->value.arr.n);
-                                        gguf_free(ctx);
-                                        return NULL;
-                                    }
-
-                                    kv->value.arr.data = calloc(kv->value.arr.n, sizeof(struct gguf_str));
-                                    if (!kv->value.arr.data) {
-                                        fprintf(stderr, "%s: failed to allocate memory for array\n", __func__);
-                                        gguf_free(ctx);
-                                        return NULL;
-                                    }
-
-                                    for (uint64_t j = 0; j < kv->value.arr.n; ++j) {
-                                        ok = ok && gguf_fread_str(file, &((struct gguf_str *) kv->value.arr.data)[j], &offset);
-                                    }
-                                } break;
-                            case GGUF_TYPE_ARRAY:
-                            default:
-                                {
-                                    fprintf(stderr, "%s: invalid array type %d\n", __func__, kv->value.arr.type);
-                                    ok = false;
-                                } break;
-                        }
-                    } break;
-                default:
-                    {
-                        fprintf(stderr, "%s: invalid type %d\n", __func__, kv->type);
-                        ok = false;
-                    } break;
-            }
-
-            if (!ok) {
-                break;
-            }
-        }
-
-        if (!ok) {
-            fprintf(stderr, "%s: failed to read key-value pairs\n", __func__);
-            gguf_free(ctx);
-            return NULL;
-        }
-    }
-
-    // read the tensor infos
-    if (ctx->header.n_tensors > 0) {
-        ctx->infos = calloc(ctx->header.n_tensors, sizeof(struct gguf_tensor_info));
-        if (!ctx->infos) {
-            fprintf(stderr, "%s: failed to allocate memory for tensor infos\n", __func__);
-            gguf_free(ctx);
-            return NULL;
-        }
-
-        for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
-            struct gguf_tensor_info * info = &ctx->infos[i];
-
-            for (int j = 0; j < GGML_MAX_DIMS; ++j) {
-                info->ne[j] = 1;
-            }
-
-            ok = ok && gguf_fread_str(file, &info->name,                          &offset);
-            ok = ok && gguf_fread_el (file, &info->n_dims, sizeof(info->n_dims),  &offset);
-
-            ok = ok && (info->n_dims <= GGML_MAX_DIMS);
-
-            for (uint32_t j = 0; j < info->n_dims; ++j) {
-                ok = ok && gguf_fread_el(file, &info->ne[j], sizeof(info->ne[j]), &offset);
-            }
-
-            ok = ok && gguf_fread_el (file, &info->type,   sizeof(info->type),    &offset);
-            ok = ok && gguf_fread_el (file, &info->offset, sizeof(info->offset),  &offset);
-
-            ok = ok && gguf_tensor_info_sanitize(info);
-
-            // make sure there is no duplicated tensor names
-            for (uint64_t j = 0; j < i && ok; ++j) {
-                if (strcmp(info->name.data, ctx->infos[j].name.data) == 0) {
-                    fprintf(stderr, "%s: duplicated tensor name %s\n", __func__, info->name.data);
-                    ok = false;
-                }
-            }
-
-            if (!ok) {
-                fprintf(stderr, "%s: failed to read tensor info\n", __func__);
-                gguf_free(ctx);
-                return NULL;
-            }
-        }
-    }
-
-    ctx->alignment = GGUF_DEFAULT_ALIGNMENT;
-
-    int alignment_idx = gguf_find_key(ctx, "general.alignment");
-    if (alignment_idx != -1) {
-        ctx->alignment = gguf_get_val_u32(ctx, alignment_idx);
-    }
-
-    // we require the data section to be aligned, so take into account any padding
-    {
-        const size_t offset_pad = offset % ctx->alignment;
-
-        if (offset_pad != 0) {
-            offset += ctx->alignment - offset_pad;
-            fseek(file, offset, SEEK_SET);
-        }
-    }
-
-    // store the current file offset - this is where the data section starts
-    ctx->offset = offset;
-
-    // compute the total size of the data section, taking into account the alignment
-    {
-        ctx->size = 0;
-        for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
-            struct gguf_tensor_info * info = &ctx->infos[i];
-
-            const int64_t ne =
-                (int64_t) info->ne[0] *
-                (int64_t) info->ne[1] *
-                (int64_t) info->ne[2] *
-                (int64_t) info->ne[3];
-
-            if (ggml_blck_size(info->type) == 0 ) {
-                // this tensor type support have been removed:
-                fprintf(stderr, "%s: tensor '%s' of type %d: %s\n",
-                        __func__, info->name.data, (int) info->type, ggml_type_name(info->type));
-                gguf_free(ctx);
-                return NULL;
-            }
-
-            if (ne % ggml_blck_size(info->type) != 0) {
-                fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%" PRId64 ")\n",
-                        __func__, info->name.data, (int) info->type, ggml_type_name(info->type), ne, ggml_blck_size(info->type));
-                gguf_free(ctx);
-                return NULL;
-            }
-
-            const size_t size_cur = ggml_row_size(info->type, ne);
-
-            ctx->size += GGML_PAD(size_cur, ctx->alignment);
-        }
-    }
-
-    // load the tensor data only if requested
-    if (params.ctx != NULL) {
-        // if the provided gguf_context is no_alloc, then we create "empty" tensors and do not read the binary blob
-        // otherwise, we load the binary blob into the created ggml_context as well, and point the "data" members of
-        // the ggml_tensor structs to the appropriate locations in the binary blob
-
-        // compute the exact size needed for the new ggml_context
-        const size_t mem_size =
-            params.no_alloc ?
-            (ctx->header.n_tensors    )*ggml_tensor_overhead() :
-            (ctx->header.n_tensors + 1)*ggml_tensor_overhead() + ctx->size;
-
-        struct ggml_init_params pdata = {
-            .mem_size   = mem_size,
-            .mem_buffer = NULL,
-            .no_alloc   = params.no_alloc,
-        };
-
-        *params.ctx = ggml_init(pdata);
-        if (*params.ctx == NULL) {
-            fprintf(stderr, "%s: failed to initialize context\n", __func__);
-            gguf_free(ctx);
-            return NULL;
-        }
-
-        struct ggml_context * ctx_data = *params.ctx;
-
-        struct ggml_tensor * data = NULL;
-
-        if (!params.no_alloc) {
-            data = ggml_new_tensor_1d(ctx_data, GGML_TYPE_I8, ctx->size);
-
-            ok = ok && data != NULL;
-
-            // read the binary blob with the tensor data
-            ok = ok && gguf_fread_el(file, data->data, ctx->size, &offset);
-
-            if (!ok) {
-                fprintf(stderr, "%s: failed to read tensor data\n", __func__);
-                ggml_free(ctx_data);
-                gguf_free(ctx);
-                return NULL;
-            }
-
-            ctx->data = data->data;
-        }
-
-        ggml_set_no_alloc(ctx_data, true);
-
-        // create the tensors
-        for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
-            const int64_t ne[GGML_MAX_DIMS] = {
-                ctx->infos[i].ne[0],
-                ctx->infos[i].ne[1],
-                ctx->infos[i].ne[2],
-                ctx->infos[i].ne[3],
-            };
-
-            struct ggml_tensor * cur = ggml_new_tensor(ctx_data, ctx->infos[i].type, ctx->infos[i].n_dims, ne);
-
-            ok = ok && cur != NULL;
-
-            if (!ok) {
-                break;
-            }
-
-            ggml_set_name(cur, ctx->infos[i].name.data);
-
-            // point the data member to the appropriate location in the binary blob using the tensor infos
-            if (!params.no_alloc) {
-              //cur->data = (char *) data->data + ctx->infos[i].offset - ctx->offset; // offset from start of file
-                cur->data = (char *) data->data + ctx->infos[i].offset;               // offset from data
-            }
-        }
-
-        if (!ok) {
-            fprintf(stderr, "%s: failed to read the tensor data\n", __func__);
-            ggml_free(ctx_data);
-            gguf_free(ctx);
-            return NULL;
-        }
-
-        ggml_set_no_alloc(ctx_data, params.no_alloc);
-    }
-
-    return ctx;
-}
-
-struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) {
-    FILE * file = ggml_fopen(fname, "rb");
-    if (!file) {
-        fprintf(stderr, "%s: failed to open '%s': '%s'\n", __func__, fname, strerror(errno));
-        return NULL;
-    }
-
-    struct gguf_context * result = gguf_init_from_file_impl(file, params);
-    fclose(file);
-    return result;
-}
-
-void gguf_free(struct gguf_context * ctx) {
-    if (ctx == NULL) {
-        return;
-    }
-
-    if (ctx->kv) {
-        // free string memory - not great..
-        for (uint64_t i = 0; i < ctx->header.n_kv; ++i) {
-            gguf_free_kv(&ctx->kv[i]);
-        }
-
-        GGML_FREE(ctx->kv);
-    }
-
-    if (ctx->infos) {
-        for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
-            struct gguf_tensor_info * info = &ctx->infos[i];
-
-            if (info->name.data) {
-                GGML_FREE(info->name.data);
-            }
-        }
-
-        GGML_FREE(ctx->infos);
-    }
-
-    GGML_FREE(ctx);
-}
-
-const char * gguf_type_name(enum gguf_type type) {
-    return GGUF_TYPE_NAME[type];
-}
-
-int gguf_get_version(const struct gguf_context * ctx) {
-    return ctx->header.version;
-}
-
-size_t gguf_get_alignment(const struct gguf_context * ctx) {
-    return ctx->alignment;
-}
-
-size_t gguf_get_data_offset(const struct gguf_context * ctx) {
-    return ctx->offset;
-}
-
-void * gguf_get_data(const struct gguf_context * ctx) {
-    return ctx->data;
-}
-
-int gguf_get_n_kv(const struct gguf_context * ctx) {
-    return ctx->header.n_kv;
-}
-
-int gguf_find_key(const struct gguf_context * ctx, const char * key) {
-    // return -1 if key not found
-    int keyfound = -1;
-
-    const int n_kv = gguf_get_n_kv(ctx);
-
-    for (int i = 0; i < n_kv; ++i) {
-        if (strcmp(key, gguf_get_key(ctx, i)) == 0) {
-            keyfound = i;
-            break;
-        }
-    }
-
-    return keyfound;
-}
-
-const char * gguf_get_key(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    return ctx->kv[key_id].key.data;
-}
-
-enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    return ctx->kv[key_id].type;
-}
-
-enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
-    return ctx->kv[key_id].value.arr.type;
-}
-
-const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
-    return ctx->kv[key_id].value.arr.data;
-}
-
-const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
-    struct gguf_kv * kv = &ctx->kv[key_id];
-    struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i];
-    return str->data;
-}
-
-int gguf_get_arr_n(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
-    return ctx->kv[key_id].value.arr.n;
-}
-
-uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT8);
-    return ctx->kv[key_id].value.uint8;
-}
-
-int8_t gguf_get_val_i8(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT8);
-    return ctx->kv[key_id].value.int8;
-}
-
-uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT16);
-    return ctx->kv[key_id].value.uint16;
-}
-
-int16_t gguf_get_val_i16(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT16);
-    return ctx->kv[key_id].value.int16;
-}
-
-uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT32);
-    return ctx->kv[key_id].value.uint32;
-}
-
-int32_t gguf_get_val_i32(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT32);
-    return ctx->kv[key_id].value.int32;
-}
-
-float gguf_get_val_f32(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT32);
-    return ctx->kv[key_id].value.float32;
-}
-
-uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT64);
-    return ctx->kv[key_id].value.uint64;
-}
-
-int64_t gguf_get_val_i64(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT64);
-    return ctx->kv[key_id].value.int64;
-}
-
-double gguf_get_val_f64(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT64);
-    return ctx->kv[key_id].value.float64;
-}
-
-bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_BOOL);
-    return ctx->kv[key_id].value.bool_;
-}
-
-const char * gguf_get_val_str(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_STRING);
-    return ctx->kv[key_id].value.str.data;
-}
-
-const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_ARRAY);
-    GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_STRING);
-    return &ctx->kv[key_id].value;
-}
-
-int gguf_get_n_tensors(const struct gguf_context * ctx) {
-    return ctx->header.n_tensors;
-}
-
-int gguf_find_tensor(const struct gguf_context * ctx, const char * name) {
-    // return -1 if tensor not found
-    int tensorfound = -1;
-
-    const int n_tensors = gguf_get_n_tensors(ctx);
-
-    for (int i = 0; i < n_tensors; ++i) {
-        if (strcmp(name, gguf_get_tensor_name(ctx, i)) == 0) {
-            tensorfound = i;
-            break;
-        }
-    }
-
-    return tensorfound;
-}
-
-size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i) {
-    return ctx->infos[i].offset;
-}
-
-char * gguf_get_tensor_name(const struct gguf_context * ctx, int i) {
-    return ctx->infos[i].name.data;
-}
-
-enum ggml_type gguf_get_tensor_type(const struct gguf_context * ctx, int i) {
-    return ctx->infos[i].type;
-}
-
-// returns the index
-static int gguf_get_or_add_key(struct gguf_context * ctx, const char * key) {
-    const int idx = gguf_find_key(ctx, key);
-    if (idx >= 0) {
-        return idx;
-    }
-
-    const int n_kv = gguf_get_n_kv(ctx);
-
-    ctx->kv = realloc(ctx->kv, (n_kv + 1) * sizeof(struct gguf_kv));
-    ctx->kv[n_kv].key.n    = strlen(key);
-    ctx->kv[n_kv].key.data = strdup(key);
-    ctx->header.n_kv++;
-
-    return n_kv;
-}
-
-void gguf_remove_key(struct gguf_context * ctx, const char * key) {
-    const int idx = gguf_find_key(ctx, key);
-    if (idx >= 0) {
-        const int n_kv = gguf_get_n_kv(ctx);
-        gguf_free_kv(&ctx->kv[idx]);
-        for (int i = idx; i < n_kv-1; ++i) {
-            ctx->kv[i] = ctx->kv[i+1];
-        }
-        ctx->kv = realloc(ctx->kv, (n_kv - 1) * sizeof(struct gguf_kv));
-        ctx->header.n_kv--;
-    }
-}
-
-void gguf_set_val_u8(struct gguf_context * ctx, const char * key, uint8_t val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type        = GGUF_TYPE_UINT8;
-    ctx->kv[idx].value.uint8 = val;
-}
-
-void gguf_set_val_i8(struct gguf_context * ctx, const char * key, int8_t val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type       = GGUF_TYPE_INT8;
-    ctx->kv[idx].value.int8 = val;
-}
-
-void gguf_set_val_u16(struct gguf_context * ctx, const char * key, uint16_t val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type         = GGUF_TYPE_UINT16;
-    ctx->kv[idx].value.uint16 = val;
-}
-
-void gguf_set_val_i16(struct gguf_context * ctx, const char * key, int16_t val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type        = GGUF_TYPE_INT16;
-    ctx->kv[idx].value.int16 = val;
-}
-
-void gguf_set_val_u32(struct gguf_context * ctx, const char * key, uint32_t val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type         = GGUF_TYPE_UINT32;
-    ctx->kv[idx].value.uint32 = val;
-}
-
-void gguf_set_val_i32(struct gguf_context * ctx, const char * key, int32_t val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type        = GGUF_TYPE_INT32;
-    ctx->kv[idx].value.int32 = val;
-}
-
-void gguf_set_val_f32(struct gguf_context * ctx, const char * key, float val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type          = GGUF_TYPE_FLOAT32;
-    ctx->kv[idx].value.float32 = val;
-}
-
-void gguf_set_val_u64(struct gguf_context * ctx, const char * key, uint64_t val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type         = GGUF_TYPE_UINT64;
-    ctx->kv[idx].value.uint64 = val;
-}
-
-void gguf_set_val_i64(struct gguf_context * ctx, const char * key, int64_t val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type        = GGUF_TYPE_INT64;
-    ctx->kv[idx].value.int64 = val;
-}
-
-void gguf_set_val_f64(struct gguf_context * ctx, const char * key, double val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type          = GGUF_TYPE_FLOAT64;
-    ctx->kv[idx].value.float64 = val;
-}
-
-void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type        = GGUF_TYPE_BOOL;
-    ctx->kv[idx].value.bool_ = val;
-}
-
-void gguf_set_val_str(struct gguf_context * ctx, const char * key, const char * val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type           = GGUF_TYPE_STRING;
-    ctx->kv[idx].value.str.n    = strlen(val);
-    ctx->kv[idx].value.str.data = strdup(val);
-}
-
-void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type           = GGUF_TYPE_ARRAY;
-    ctx->kv[idx].value.arr.type = type;
-    ctx->kv[idx].value.arr.n    = n;
-    ctx->kv[idx].value.arr.data = GGML_CALLOC(n, gguf_type_size(type));
-    memcpy(ctx->kv[idx].value.arr.data, data, n*gguf_type_size(type));
-}
-
-void gguf_set_arr_str(struct gguf_context * ctx, const char * key, const char ** data, int n) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type           = GGUF_TYPE_ARRAY;
-    ctx->kv[idx].value.arr.type = GGUF_TYPE_STRING;
-    ctx->kv[idx].value.arr.n    = n;
-    ctx->kv[idx].value.arr.data = GGML_CALLOC(n, sizeof(struct gguf_str));
-    for (int i = 0; i < n; i++) {
-        struct gguf_str * str = &((struct gguf_str *)ctx->kv[idx].value.arr.data)[i];
-        str->n    = strlen(data[i]);
-        str->data = strdup(data[i]);
-    }
-}
-
-// set or add KV pairs from another context
-void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) {
-    for (uint32_t i = 0; i < src->header.n_kv; i++) {
-        switch (src->kv[i].type) {
-            case GGUF_TYPE_UINT8:   gguf_set_val_u8  (ctx, src->kv[i].key.data, src->kv[i].value.uint8);    break;
-            case GGUF_TYPE_INT8:    gguf_set_val_i8  (ctx, src->kv[i].key.data, src->kv[i].value.int8);     break;
-            case GGUF_TYPE_UINT16:  gguf_set_val_u16 (ctx, src->kv[i].key.data, src->kv[i].value.uint16);   break;
-            case GGUF_TYPE_INT16:   gguf_set_val_i16 (ctx, src->kv[i].key.data, src->kv[i].value.int16);    break;
-            case GGUF_TYPE_UINT32:  gguf_set_val_u32 (ctx, src->kv[i].key.data, src->kv[i].value.uint32);   break;
-            case GGUF_TYPE_INT32:   gguf_set_val_i32 (ctx, src->kv[i].key.data, src->kv[i].value.int32);    break;
-            case GGUF_TYPE_FLOAT32: gguf_set_val_f32 (ctx, src->kv[i].key.data, src->kv[i].value.float32);  break;
-            case GGUF_TYPE_UINT64:  gguf_set_val_u64 (ctx, src->kv[i].key.data, src->kv[i].value.uint64);   break;
-            case GGUF_TYPE_INT64:   gguf_set_val_i64 (ctx, src->kv[i].key.data, src->kv[i].value.int64);    break;
-            case GGUF_TYPE_FLOAT64: gguf_set_val_f64 (ctx, src->kv[i].key.data, src->kv[i].value.float64);  break;
-            case GGUF_TYPE_BOOL:    gguf_set_val_bool(ctx, src->kv[i].key.data, src->kv[i].value.bool_);    break;
-            case GGUF_TYPE_STRING:  gguf_set_val_str (ctx, src->kv[i].key.data, src->kv[i].value.str.data); break;
-            case GGUF_TYPE_ARRAY:
-                {
-                    if (src->kv[i].value.arr.type == GGUF_TYPE_STRING) {
-                        const char ** data = GGML_CALLOC(src->kv[i].value.arr.n, sizeof(char *));
-                        for (uint32_t j = 0; j < src->kv[i].value.arr.n; j++) {
-                            data[j] = ((struct gguf_str *)src->kv[i].value.arr.data)[j].data;
-                        }
-                        gguf_set_arr_str(ctx, src->kv[i].key.data, data, src->kv[i].value.arr.n);
-                        GGML_FREE((void *)data);
-                    } else if (src->kv[i].value.arr.type == GGUF_TYPE_ARRAY) {
-                        GGML_ABORT("nested arrays not supported");
-                    } else {
-                        gguf_set_arr_data(ctx, src->kv[i].key.data, src->kv[i].value.arr.type, src->kv[i].value.arr.data, src->kv[i].value.arr.n);
-                    }
-                } break;
-            default: GGML_ABORT("invalid type");
-        }
-    }
-}
-
-void gguf_add_tensor(
-             struct gguf_context * ctx,
-        const struct ggml_tensor * tensor) {
-    GGML_ASSERT(tensor);
-    if (gguf_find_tensor(ctx, tensor->name) != -1) {
-        GGML_ABORT("duplicated tensor name");
-    }
-
-    const int idx = ctx->header.n_tensors;
-    ctx->infos = realloc(ctx->infos, (idx + 1)*sizeof(struct gguf_tensor_info));
-
-    ctx->infos[idx].name.n    = strlen(tensor->name);
-    ctx->infos[idx].name.data = strdup(tensor->name);
-
-    for (int i = 0; i < GGML_MAX_DIMS; ++i) {
-        ctx->infos[idx].ne[i] = 1;
-    }
-
-    ctx->infos[idx].n_dims = ggml_n_dims(tensor);
-    for (uint32_t i = 0; i < ctx->infos[idx].n_dims; i++) {
-        ctx->infos[idx].ne[i] = tensor->ne[i];
-    }
-
-    ctx->infos[idx].type   = tensor->type;
-    ctx->infos[idx].offset = 0;
-    ctx->infos[idx].data   = tensor->data;
-    ctx->infos[idx].size   = ggml_nbytes(tensor);
-
-    if (ctx->header.n_tensors > 0) {
-        ctx->infos[idx].offset = ctx->infos[idx - 1].offset + GGML_PAD(ctx->infos[idx - 1].size, ctx->alignment);
-    }
-
-    ctx->header.n_tensors++;
-}
-
-void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type) {
-    const int idx = gguf_find_tensor(ctx, name);
-    if (idx < 0) {
-        GGML_ABORT("tensor not found");
-    }
-
-    ctx->infos[idx].type = type;
-}
-
-void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size) {
-    const int idx = gguf_find_tensor(ctx, name);
-    if (idx < 0) {
-        GGML_ABORT("tensor not found");
-    }
-
-    ctx->infos[idx].data = data;
-    ctx->infos[idx].size = size;
-
-    // update offsets
-    for (uint32_t i = idx + 1; i < ctx->header.n_tensors; ++i) {
-        ctx->infos[i].offset = ctx->infos[i - 1].offset + GGML_PAD(ctx->infos[i - 1].size, ctx->alignment);
-    }
-}
-
-//static void gguf_fwrite_str(FILE * file, const struct gguf_str * val) {
-//    fwrite(&val->n,   sizeof(val->n),    1, file);
-//    fwrite(val->data, sizeof(char), val->n, file);
-//}
-//
-//static void gguf_fwrite_el(FILE * file, const void * val, size_t size) {
-//    fwrite(val, sizeof(char), size, file);
-//}
-
-struct gguf_buf gguf_buf_init(size_t size) {
-    struct gguf_buf buf = {
-        /*buf.data   =*/ size == 0 ? NULL : GGML_CALLOC(1, size),
-        /*buf.size   =*/ size,
-        /*buf.offset =*/ 0,
-    };
-
-    return buf;
-}
-
-void gguf_buf_free(struct gguf_buf buf) {
-    if (buf.data) {
-        GGML_FREE(buf.data);
-    }
-}
-
-static void gguf_buf_grow(struct gguf_buf * buf, size_t size) {
-    if (buf->offset + size > buf->size) {
-        buf->size = 1.5*(buf->offset + size);
-        if (buf->data) {
-            buf->data = realloc(buf->data, buf->size);
-        }
-    }
-}
-
-static void gguf_bwrite_str(struct gguf_buf * buf, const struct gguf_str * val) {
-    gguf_buf_grow(buf, sizeof(val->n) + val->n);
-
-    if (buf->data) {
-        memcpy((char *) buf->data + buf->offset, &val->n, sizeof(val->n));
-    }
-    buf->offset += sizeof(val->n);
-
-    if (buf->data) {
-        memcpy((char *) buf->data + buf->offset, val->data, val->n);
-    }
-    buf->offset += val->n;
-}
-
-static void gguf_bwrite_el(struct gguf_buf * buf, const void * val, size_t el_size) {
-    gguf_buf_grow(buf, el_size);
-
-    if (buf->data) {
-        memcpy((char *) buf->data + buf->offset, val, el_size);
-    }
-    buf->offset += el_size;
-}
-
-void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta) {
-    // write header
-    gguf_bwrite_el(buf, &ctx->header.magic,     sizeof(ctx->header.magic));
-    gguf_bwrite_el(buf, &ctx->header.version,   sizeof(ctx->header.version));
-    gguf_bwrite_el(buf, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors));
-    gguf_bwrite_el(buf, &ctx->header.n_kv,      sizeof(ctx->header.n_kv));
-
-    // write key-value pairs
-    for (uint32_t i = 0; i < ctx->header.n_kv; ++i) {
-        struct gguf_kv * kv = &ctx->kv[i];
-
-        gguf_bwrite_str(buf, &kv->key);
-        gguf_bwrite_el (buf, &kv->type, sizeof(kv->type));
-
-        switch (kv->type) {
-            case GGUF_TYPE_UINT8:   gguf_bwrite_el( buf, &kv->value.uint8,   sizeof(kv->value.uint8)  ); break;
-            case GGUF_TYPE_INT8:    gguf_bwrite_el (buf, &kv->value.int8,    sizeof(kv->value.int8)   ); break;
-            case GGUF_TYPE_UINT16:  gguf_bwrite_el (buf, &kv->value.uint16,  sizeof(kv->value.uint16) ); break;
-            case GGUF_TYPE_INT16:   gguf_bwrite_el (buf, &kv->value.int16,   sizeof(kv->value.int16)  ); break;
-            case GGUF_TYPE_UINT32:  gguf_bwrite_el (buf, &kv->value.uint32,  sizeof(kv->value.uint32) ); break;
-            case GGUF_TYPE_INT32:   gguf_bwrite_el (buf, &kv->value.int32,   sizeof(kv->value.int32)  ); break;
-            case GGUF_TYPE_FLOAT32: gguf_bwrite_el (buf, &kv->value.float32, sizeof(kv->value.float32)); break;
-            case GGUF_TYPE_UINT64:  gguf_bwrite_el (buf, &kv->value.uint64,  sizeof(kv->value.uint64) ); break;
-            case GGUF_TYPE_INT64:   gguf_bwrite_el (buf, &kv->value.int64,   sizeof(kv->value.int64)  ); break;
-            case GGUF_TYPE_FLOAT64: gguf_bwrite_el (buf, &kv->value.float64, sizeof(kv->value.float64)); break;
-            case GGUF_TYPE_BOOL:    gguf_bwrite_el (buf, &kv->value.bool_,   sizeof(kv->value.bool_)  ); break;
-            case GGUF_TYPE_STRING:  gguf_bwrite_str(buf, &kv->value.str                               ); break;
-            case GGUF_TYPE_ARRAY:
-                {
-                    gguf_bwrite_el(buf, &kv->value.arr.type, sizeof(kv->value.arr.type));
-                    gguf_bwrite_el(buf, &kv->value.arr.n,    sizeof(kv->value.arr.n)   );
-
-                    switch (kv->value.arr.type) {
-                        case GGUF_TYPE_UINT8:
-                        case GGUF_TYPE_INT8:
-                        case GGUF_TYPE_UINT16:
-                        case GGUF_TYPE_INT16:
-                        case GGUF_TYPE_UINT32:
-                        case GGUF_TYPE_INT32:
-                        case GGUF_TYPE_FLOAT32:
-                        case GGUF_TYPE_UINT64:
-                        case GGUF_TYPE_INT64:
-                        case GGUF_TYPE_FLOAT64:
-                        case GGUF_TYPE_BOOL:
-                            {
-                                gguf_bwrite_el(buf, kv->value.arr.data, kv->value.arr.n * gguf_type_size(kv->value.arr.type));
-                            } break;
-                        case GGUF_TYPE_STRING:
-                            {
-                                for (uint32_t j = 0; j < kv->value.arr.n; ++j) {
-                                    gguf_bwrite_str(buf, &((struct gguf_str *) kv->value.arr.data)[j]);
-                                }
-                            } break;
-                        case GGUF_TYPE_ARRAY:
-                        default: GGML_ABORT("invalid type");
-                    }
-                } break;
-            default: GGML_ABORT("invalid type");
-        }
-    }
-
-    // write tensor infos
-    for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) {
-        struct gguf_tensor_info * info = &ctx->infos[i];
-
-        gguf_bwrite_str(buf, &info->name);
-        gguf_bwrite_el (buf, &info->n_dims, sizeof(info->n_dims));
-        for (uint32_t j = 0; j < info->n_dims; ++j) {
-            gguf_bwrite_el(buf, &info->ne[j], sizeof(info->ne[j]));
-        }
-        gguf_bwrite_el(buf, &info->type,   sizeof(info->type));
-        gguf_bwrite_el(buf, &info->offset, sizeof(info->offset));
-    }
-
-    // we require the data section to be aligned, so take into account any padding
-    {
-        const size_t offset     = buf->offset;
-        const size_t offset_pad = GGML_PAD(offset, ctx->alignment);
-
-        if (offset_pad != offset) {
-            uint8_t pad = 0;
-            for (size_t i = 0; i < offset_pad - offset; ++i) {
-                gguf_bwrite_el(buf, &pad, sizeof(pad));
-            }
-        }
-    }
-
-    if (only_meta) {
-        return;
-    }
-
-    size_t offset = 0;
-
-    // write tensor data
-    for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) {
-        struct gguf_tensor_info * info = &ctx->infos[i];
-
-        const size_t size     = info->size;
-        const size_t size_pad = GGML_PAD(size, ctx->alignment);
-
-        gguf_bwrite_el(buf, info->data, size);
-
-        if (size_pad != size) {
-            uint8_t pad = 0;
-            for (size_t j = 0; j < size_pad - size; ++j) {
-                gguf_bwrite_el(buf, &pad, sizeof(pad));
-            }
-        }
-
-        GGML_ASSERT(offset == info->offset);
-
-        offset += size_pad;
-    }
-}
-
-void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) {
-    FILE * file = ggml_fopen(fname, "wb");
-    if (!file) {
-        GGML_ABORT("failed to open file for writing");
-    }
-
-    struct gguf_buf buf = gguf_buf_init(16*1024);
-
-    gguf_write_to_buf(ctx, &buf, only_meta);
-
-    fwrite(buf.data, 1, buf.offset, file);
-
-    gguf_buf_free(buf);
-
-    fclose(file);
-}
-
-size_t gguf_get_meta_size(const struct gguf_context * ctx) {
-    // no allocs - only compute size
-    struct gguf_buf buf = gguf_buf_init(0);
-
-    gguf_write_to_buf(ctx, &buf, true);
-
-    return buf.offset;
-}
-
-void gguf_get_meta_data(const struct gguf_context * ctx, void * data) {
-    struct gguf_buf buf = gguf_buf_init(16*1024);
-
-    gguf_write_to_buf(ctx, &buf, true);
-
-    memcpy(data, buf.data, buf.offset);
-
-    gguf_buf_free(buf);
-}
-
 void ggml_log_set(ggml_log_callback log_callback, void * user_data) {
     g_logger_state.log_callback = log_callback ? log_callback : ggml_log_callback_default;
     g_logger_state.log_callback_user_data = user_data;
diff --git a/ml/backend/ggml/ggml/src/ggml_darwin_arm64.go b/ml/backend/ggml/ggml/src/ggml_darwin_arm64.go
index beffa64e..d4354c60 100644
--- a/ml/backend/ggml/ggml/src/ggml_darwin_arm64.go
+++ b/ml/backend/ggml/ggml/src/ggml_darwin_arm64.go
@@ -1,6 +1,6 @@
 package ggml
 
-// #cgo CPPFLAGS: -DGGML_USE_METAL -DGGML_USE_BLAS
+// #cgo CPPFLAGS: -DGGML_USE_METAL -DGGML_METAL_EMBED_LIBRARY -DGGML_USE_BLAS
 // #cgo LDFLAGS: -framework Foundation
 import "C"
 
diff --git a/ml/backend/ggml/ggml/src/gguf.cpp b/ml/backend/ggml/ggml/src/gguf.cpp
new file mode 100644
index 00000000..ab13669c
--- /dev/null
+++ b/ml/backend/ggml/ggml/src/gguf.cpp
@@ -0,0 +1,1329 @@
+#include "ggml.h"
+#include "ggml-backend.h"
+#include "ggml-impl.h"
+#include "gguf.h"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+template 
+struct type_to_gguf_type;
+
+template <>
+struct type_to_gguf_type {
+    static constexpr enum gguf_type value = GGUF_TYPE_UINT8;
+};
+
+template <>
+struct type_to_gguf_type {
+    static constexpr enum gguf_type value = GGUF_TYPE_INT8;
+};
+
+template <>
+struct type_to_gguf_type {
+    static constexpr enum gguf_type value = GGUF_TYPE_UINT16;
+};
+
+template <>
+struct type_to_gguf_type {
+    static constexpr enum gguf_type value = GGUF_TYPE_INT16;
+};
+
+template <>
+struct type_to_gguf_type {
+    static constexpr enum gguf_type value = GGUF_TYPE_UINT32;
+};
+
+template <>
+struct type_to_gguf_type {
+    static constexpr enum gguf_type value = GGUF_TYPE_INT32;
+};
+
+template <>
+struct type_to_gguf_type {
+    static constexpr enum gguf_type value = GGUF_TYPE_FLOAT32;
+};
+
+template <>
+struct type_to_gguf_type {
+    static constexpr enum gguf_type value = GGUF_TYPE_BOOL;
+};
+
+template <>
+struct type_to_gguf_type {
+    static constexpr enum gguf_type value = GGUF_TYPE_STRING;
+};
+
+template <>
+struct type_to_gguf_type {
+    static constexpr enum gguf_type value = GGUF_TYPE_UINT64;
+};
+
+template <>
+struct type_to_gguf_type {
+    static constexpr enum gguf_type value = GGUF_TYPE_INT64;
+};
+
+template <>
+struct type_to_gguf_type {
+    static constexpr enum gguf_type value = GGUF_TYPE_FLOAT64;
+};
+
+static const std::map GGUF_TYPE_SIZE = {
+    {GGUF_TYPE_UINT8,   sizeof(uint8_t)},
+    {GGUF_TYPE_INT8,    sizeof(int8_t)},
+    {GGUF_TYPE_UINT16,  sizeof(uint16_t)},
+    {GGUF_TYPE_INT16,   sizeof(int16_t)},
+    {GGUF_TYPE_UINT32,  sizeof(uint32_t)},
+    {GGUF_TYPE_INT32,   sizeof(int32_t)},
+    {GGUF_TYPE_FLOAT32, sizeof(float)},
+    {GGUF_TYPE_BOOL,    sizeof(int8_t)},
+    {GGUF_TYPE_STRING,  0}, // undefined
+    {GGUF_TYPE_ARRAY,   0}, // undefined
+    {GGUF_TYPE_UINT64,  sizeof(uint64_t)},
+    {GGUF_TYPE_INT64,   sizeof(int64_t)},
+    {GGUF_TYPE_FLOAT64, sizeof(double)},
+};
+static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13");
+
+static const std::map GGUF_TYPE_NAME = {
+    {GGUF_TYPE_UINT8,   "u8"},
+    {GGUF_TYPE_INT8,    "i8"},
+    {GGUF_TYPE_UINT16,  "u16"},
+    {GGUF_TYPE_INT16,   "i16"},
+    {GGUF_TYPE_UINT32,  "u32"},
+    {GGUF_TYPE_INT32,   "i32"},
+    {GGUF_TYPE_FLOAT32, "f32"},
+    {GGUF_TYPE_BOOL,    "bool"},
+    {GGUF_TYPE_STRING,  "str"},
+    {GGUF_TYPE_ARRAY,   "arr"},
+    {GGUF_TYPE_UINT64,  "u64"},
+    {GGUF_TYPE_INT64,   "i64"},
+    {GGUF_TYPE_FLOAT64, "f64"},
+};
+static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13");
+
+size_t gguf_type_size(enum gguf_type type) {
+    auto it = GGUF_TYPE_SIZE.find(type);
+    return it == GGUF_TYPE_SIZE.end() ? 0 : it->second;
+}
+
+struct gguf_kv {
+    std::string key;
+
+    bool is_array;
+    enum gguf_type type;
+
+    std::vector      data;
+    std::vector data_string;
+
+    template 
+    gguf_kv(const std::string & key, const T value)
+            : key(key), is_array(false), type(type_to_gguf_type::value) {
+        GGML_ASSERT(!key.empty());
+        data.resize(sizeof(T));
+        memcpy(data.data(), &value, sizeof(T));
+    }
+
+    template 
+    gguf_kv(const std::string & key, const std::vector & value)
+            : key(key), is_array(true), type(type_to_gguf_type::value) {
+        GGML_ASSERT(!key.empty());
+        data.resize(value.size()*sizeof(T));
+        for (size_t i = 0; i < value.size(); ++i) {
+            const T tmp = value[i];
+            memcpy(data.data() + i*sizeof(T), &tmp, sizeof(T));
+        }
+    }
+
+    gguf_kv(const std::string & key, const std::string & value)
+            : key(key), is_array(false), type(GGUF_TYPE_STRING) {
+        GGML_ASSERT(!key.empty());
+        data_string.push_back(value);
+    }
+
+    gguf_kv(const std::string & key, const std::vector & value)
+            : key(key), is_array(true), type(GGUF_TYPE_STRING) {
+        GGML_ASSERT(!key.empty());
+        data_string = value;
+    }
+
+    const std::string & get_key() const {
+        return key;
+    }
+
+    const enum gguf_type & get_type() const {
+        return type;
+    }
+
+    size_t get_ne() const {
+        if (type == GGUF_TYPE_STRING) {
+            const size_t ne = data_string.size();
+            GGML_ASSERT(is_array || ne == 1);
+            return ne;
+        }
+        const size_t type_size = gguf_type_size(type);
+        GGML_ASSERT(data.size() % type_size == 0);
+        const size_t ne = data.size() / type_size;
+        GGML_ASSERT(is_array || ne == 1);
+        return ne;
+    }
+
+    template 
+    const T & get_val(const size_t i = 0) const {
+        GGML_ASSERT(type_to_gguf_type::value == type);
+        if constexpr (std::is_same::value) {
+            GGML_ASSERT(data_string.size() >= i+1);
+            return data_string[i];
+        }
+        const size_t type_size = gguf_type_size(type);
+        GGML_ASSERT(data.size() % type_size == 0);
+        GGML_ASSERT(data.size() >= (i+1)*type_size);
+        return reinterpret_cast(data.data())[i];
+    }
+
+    void cast(const enum gguf_type new_type) {
+        const size_t new_type_size = gguf_type_size(new_type);
+        GGML_ASSERT(data.size() % new_type_size == 0);
+        type = new_type;
+    }
+};
+
+struct gguf_tensor_info {
+    struct ggml_tensor t; // for holding the equivalent info
+    uint64_t offset;      // offset from start of `data`, must be a multiple of `ALIGNMENT`
+};
+
+struct gguf_context {
+    uint32_t version = GGUF_VERSION;
+
+    std::vector kv;
+    std::vector info;
+
+    size_t alignment = GGUF_DEFAULT_ALIGNMENT;
+    size_t offset    = 0; // offset of `data` from beginning of file
+    size_t size      = 0; // size of `data` in bytes
+
+    void * data = nullptr;
+};
+
+struct gguf_reader {
+    FILE * file;
+
+    gguf_reader(FILE * file) : file(file) {}
+
+    template 
+    bool read(T & dst) const {
+        return fread(&dst, 1, sizeof(dst), file) == sizeof(dst);
+    }
+
+    template 
+    bool read(std::vector & dst, const size_t n) const {
+        dst.resize(n);
+        for (size_t i = 0; i < dst.size(); ++i) {
+            if constexpr (std::is_same::value) {
+                bool tmp;
+                if (!read(tmp)) {
+                    return false;
+                }
+                dst[i] = tmp;
+            } else {
+                if (!read(dst[i])) {
+                    return false;
+                }
+            }
+        }
+        return true;
+    }
+
+    bool read(bool & dst) const {
+        int8_t tmp = -1;
+        if (!read(tmp)) {
+            return false;
+        }
+        dst = tmp != 0;
+        return true;
+    }
+
+    bool read(enum ggml_type & dst) const {
+        int32_t tmp = -1;
+        if (!read(tmp)) {
+            return false;
+        }
+        dst = ggml_type(tmp);
+        return true;
+    }
+
+    bool read(enum gguf_type & dst) const {
+        int32_t tmp = -1;
+        if (!read(tmp)) {
+            return false;
+        }
+        dst = gguf_type(tmp);
+        return true;
+    }
+
+    bool read(std::string & dst) const {
+        uint64_t size = -1;
+        if (!read(size)) {
+            return false;
+        }
+        dst.resize(size);
+        return fread(dst.data(), 1, dst.length(), file) == dst.length();
+    }
+
+    bool read(void * dst, const size_t size) const {
+        return fread(dst, 1, size, file) == size;
+    }
+};
+
+struct gguf_context * gguf_init_empty(void) {
+    return new gguf_context;
+}
+
+template
+bool gguf_read_emplace_helper(const struct gguf_reader & gr, std::vector & kv, const std::string & key, const bool is_array, const size_t n) {
+    if (is_array) {
+        std::vector value;
+        try {
+            if (!gr.read(value, n)) {
+                return false;
+            }
+        } catch (std::length_error &) {
+            fprintf(stderr, "%s: encountered length_error while reading value for key '%s'\n", __func__, key.c_str());
+            return false;
+        } catch (std::bad_alloc &) {
+            fprintf(stderr, "%s: encountered bad_alloc error while reading value for key '%s'\n", __func__, key.c_str());
+            return false;
+        }
+        kv.emplace_back(key, value);
+    } else {
+        T value;
+        if (!gr.read(value)) {
+            return false;
+        }
+        kv.emplace_back(key, value);
+    }
+    return true;
+}
+
+struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params) {
+    const struct gguf_reader gr(file);
+    struct gguf_context * ctx = new gguf_context;
+
+    bool ok = true;
+
+    // file magic
+    {
+        std::vector magic;
+        ok = ok && gr.read(magic, 4);
+
+        if (!ok) {
+            fprintf(stderr, "%s: failed to read magic\n", __func__);
+            gguf_free(ctx);
+            return nullptr;
+        }
+
+        for (uint32_t i = 0; i < magic.size(); i++) {
+            if (magic[i] != GGUF_MAGIC[i]) {
+                fprintf(stderr, "%s: invalid magic characters: '%c%c%c%c', expected 'GGUF'\n", __func__, magic[0], magic[1], magic[2], magic[3]);
+                gguf_free(ctx);
+                return nullptr;
+            }
+        }
+    }
+
+    // header
+    int64_t n_kv      = 0;
+    int64_t n_tensors = 0;
+
+    if (ok && gr.read(ctx->version)) {
+        if (ctx->version == 1) {
+            fprintf(stderr, "%s: GGUFv1 is no longer supported, please use a more up-to-date version\n", __func__);
+            ok = false;
+        }
+        if (ctx->version > GGUF_VERSION) {
+            fprintf(stderr, "%s: this GGUF file is version %" PRIu32 " but this software only supports up to version %d\n",
+                __func__, ctx->version, GGUF_VERSION);
+            ok = false;
+        }
+    } else {
+        ok = false;
+    }
+
+    if (ok && gr.read(n_tensors)) {
+        static_assert(sizeof(size_t) <= 8 && sizeof(gguf_tensor_info) >= 2, "int64_t insufficient for indexing");
+        if (n_tensors < 0 || n_tensors > int64_t(SIZE_MAX/sizeof(gguf_tensor_info))) {
+            fprintf(stderr, "%s: number of tensors is %" PRIi64 " but must be in [0, %zu]\n",
+                __func__, n_tensors, SIZE_MAX/sizeof(gguf_tensor_info));
+            ok = false;
+        }
+    } else {
+        ok = false;
+    }
+
+    if (ok && gr.read(n_kv)) {
+        static_assert(sizeof(size_t) <= 8 && sizeof(gguf_tensor_info) >= 2, "int64_t insufficient for indexing");
+        if (n_kv < 0 || n_kv > int64_t(SIZE_MAX/sizeof(gguf_kv))) {
+            fprintf(stderr, "%s: number of key value pairs is %" PRIi64 " but must be in [0, %zu]\n",
+                    __func__, n_kv, SIZE_MAX/sizeof(gguf_kv));
+            ok = false;
+        }
+    } else {
+        ok = false;
+    }
+
+    if (!ok) {
+        fprintf(stderr, "%s: failed to read header\n", __func__);
+        gguf_free(ctx);
+        return nullptr;
+    }
+
+    // KV pairs
+    {
+        for (int64_t i = 0; ok && i < n_kv; ++i) {
+            std::string key;
+            gguf_type   type     = gguf_type(-1);
+            bool        is_array = false;
+            uint64_t    n        = 1;
+
+            try {
+                ok = ok && gr.read(key);
+            } catch (std::length_error &) {
+                fprintf(stderr, "%s: encountered length_error while reading key %" PRIi64 "\n", __func__, i);
+                ok = false;
+            } catch (std::bad_alloc &) {
+                fprintf(stderr, "%s: encountered bad_alloc error while reading key %" PRIi64 "\n", __func__, i);
+                ok = false;
+            }
+            for (size_t j = 0; ok && j < ctx->kv.size(); ++j) {
+                if (key == ctx->kv[j].key) {
+                    fprintf(stderr, "%s: duplicate key '%s' for tensors %zu and %" PRIi64 " \n", __func__, key.c_str(), j, i);
+                    ok = false;
+                }
+            }
+            if (!ok) {
+                break;
+            }
+
+            ok = ok && gr.read(type);
+            if (type == GGUF_TYPE_ARRAY) {
+                is_array = true;
+                ok = ok && gr.read(type);
+                ok = ok && gr.read(n);
+            }
+            if (!ok) {
+                break;
+            }
+
+            switch (type) {
+                case GGUF_TYPE_UINT8:   ok = ok && gguf_read_emplace_helper    (gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_INT8:    ok = ok && gguf_read_emplace_helper     (gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_UINT16:  ok = ok && gguf_read_emplace_helper   (gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_INT16:   ok = ok && gguf_read_emplace_helper    (gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_UINT32:  ok = ok && gguf_read_emplace_helper   (gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_INT32:   ok = ok && gguf_read_emplace_helper    (gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_FLOAT32: ok = ok && gguf_read_emplace_helper      (gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_BOOL:    ok = ok && gguf_read_emplace_helper       (gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_STRING:  ok = ok && gguf_read_emplace_helper(gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_UINT64:  ok = ok && gguf_read_emplace_helper   (gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_INT64:   ok = ok && gguf_read_emplace_helper    (gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_FLOAT64: ok = ok && gguf_read_emplace_helper     (gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_ARRAY:
+                default:
+                    {
+                        fprintf(stderr, "%s: key '%s' has invalid GGUF type %d\n", __func__, key.c_str(), type);
+                        ok = false;
+                    } break;
+            }
+        }
+
+        if (!ok) {
+            fprintf(stderr, "%s: failed to read key-value pairs\n", __func__);
+            gguf_free(ctx);
+            return nullptr;
+        }
+        GGML_ASSERT(int64_t(ctx->kv.size()) == n_kv);
+
+        const int alignment_idx = gguf_find_key(ctx, GGUF_KEY_GENERAL_ALIGNMENT);
+        ctx->alignment = alignment_idx == -1 ? GGUF_DEFAULT_ALIGNMENT : gguf_get_val_u32(ctx, alignment_idx);
+
+        if (ctx->alignment == 0 || (ctx->alignment & (ctx->alignment - 1)) != 0) {
+            fprintf(stderr, "%s: alignment %zu is not a power of 2\n", __func__, ctx->alignment);
+            gguf_free(ctx);
+            return nullptr;
+        }
+    }
+
+    // read the tensor info
+    for (int64_t i = 0; ok && i < n_tensors; ++i) {
+        struct gguf_tensor_info info;
+
+        // tensor name
+        {
+            std::string name;
+            try {
+                ok = ok && gr.read(name);
+            } catch (std::length_error &) {
+                fprintf(stderr, "%s: encountered length_error while reading tensor name %" PRIi64 "\n", __func__, i);
+                ok = false;
+            } catch (std::bad_alloc &) {
+                fprintf(stderr, "%s: encountered bad_alloc error while reading tensor name %" PRIi64 "\n", __func__, i);
+                ok = false;
+            }
+            if (name.length() >= GGML_MAX_NAME) {
+                fprintf(stderr, "%s: tensor name %" PRIi64 " is too long: %zu >= %d\n", __func__, i, name.length(), GGML_MAX_NAME);
+                ok = false;
+                break;
+            }
+            ggml_set_name(&info.t, name.c_str());
+
+            // make sure there are no duplicate tensor names
+            for (int64_t j = 0; ok && j < i; ++j) {
+                if (strcmp(info.t.name, ctx->info[j].t.name) == 0) {
+                    fprintf(stderr, "%s: duplicate tensor name '%s' for tensors %" PRIi64 " and %" PRIi64 "\n", __func__, info.t.name, j, i);
+                    ok = false;
+                    break;
+                }
+            }
+        }
+        if (!ok) {
+            break;
+        }
+
+        // tensor shape
+        {
+            uint32_t n_dims = -1;
+            ok = ok && gr.read(n_dims);
+            if (n_dims > GGML_MAX_DIMS) {
+                fprintf(stderr, "%s: tensor '%s' has invalid number of dimensions: %" PRIu32 " > %" PRIu32 "\n",
+                    __func__, info.t.name, n_dims, GGML_MAX_DIMS);
+                ok = false;
+                break;
+            }
+            for (uint32_t j = 0; ok && j < GGML_MAX_DIMS; ++j) {
+                info.t.ne[j] = 1;
+                if (j < n_dims) {
+                    ok = ok && gr.read(info.t.ne[j]);
+                }
+
+                // check that all ne are non-negative
+                if (info.t.ne[j] < 0) {
+                    fprintf(stderr, "%s: tensor '%s' dimension %" PRIu32 " has invalid number of elements: %" PRIi64 " < 0\n",
+                        __func__, info.t.name, j, info.t.ne[j]);
+                    ok = false;
+                    break;
+                }
+            }
+
+            // check that the total number of elements is representable
+            if (ok && ((INT64_MAX/info.t.ne[1] <= info.t.ne[0]) ||
+                       (INT64_MAX/info.t.ne[2] <= info.t.ne[0]*info.t.ne[1]) ||
+                       (INT64_MAX/info.t.ne[3] <= info.t.ne[0]*info.t.ne[1]*info.t.ne[2]))) {
+
+                fprintf(stderr, "%s: total number of elements in tensor '%s' with shape "
+                    "(%" PRIi64 ", %" PRIi64 ", %" PRIi64 ", %" PRIi64 ") is >= %" PRIi64 "\n",
+                    __func__, info.t.name, info.t.ne[0], info.t.ne[1], info.t.ne[2], info.t.ne[3], INT64_MAX);
+                ok = false;
+                break;
+            }
+        }
+        if (!ok) {
+            break;
+        }
+
+        // tensor type
+        {
+            ok = ok && gr.read(info.t.type);
+
+            // check that tensor type is within defined range
+            if (info.t.type < 0 || info.t.type >= GGML_TYPE_COUNT) {
+                fprintf(stderr, "%s: tensor '%s' has invalid ggml type %d (%s)\n",
+                    __func__, info.t.name, info.t.type, ggml_type_name(info.t.type));
+                ok = false;
+                break;
+            }
+            const size_t  type_size = ggml_type_size(info.t.type);
+            const int64_t blck_size = ggml_blck_size(info.t.type);
+
+            // check that row size is divisible by block size
+            if (blck_size == 0 || info.t.ne[0] % blck_size != 0) {
+                fprintf(stderr, "%s: tensor '%s' of type %d (%s) has %" PRId64 " elements per row, "
+                    "not a multiple of block size (%" PRId64 ")\n",
+                    __func__, info.t.name, (int) info.t.type, ggml_type_name(info.t.type), info.t.ne[0], blck_size);
+                ok = false;
+                break;
+            }
+
+            // calculate byte offsets given the tensor shape and type
+            info.t.nb[0] = type_size;
+            info.t.nb[1] = info.t.nb[0]*(info.t.ne[0]/blck_size);
+            for (int j = 2; j < GGML_MAX_DIMS; ++j) {
+                info.t.nb[j] = info.t.nb[j - 1]*info.t.ne[j - 1];
+            }
+        }
+        if (!ok) {
+            break;
+        }
+
+        // tensor data offset within buffer
+        ok = ok && gr.read(info.offset);
+
+        ctx->info.push_back(info);
+    }
+
+    if (!ok) {
+        fprintf(stderr, "%s: failed to read tensor info\n", __func__);
+        gguf_free(ctx);
+        return nullptr;
+    }
+    GGML_ASSERT(int64_t(ctx->info.size()) == n_tensors);
+
+    // we require the data section to be aligned, so take into account any padding
+    if (fseek(file, GGML_PAD(ftell(file), ctx->alignment), SEEK_SET) != 0) {
+        fprintf(stderr, "%s: failed to seek to beginning of data section\n", __func__);
+        gguf_free(ctx);
+        return nullptr;
+    }
+
+    // store the current file offset - this is where the data section starts
+    ctx->offset = ftell(file);
+
+    // compute the total size of the data section, taking into account the alignment
+    {
+        ctx->size = 0;
+        for (size_t i = 0; i < ctx->info.size(); ++i) {
+            const gguf_tensor_info & ti = ctx->info[i];
+            if (ti.offset != ctx->size) {
+                fprintf(stderr, "%s: tensor '%s' has offset %" PRIu64 ", expected %zu\n",
+                    __func__, ti.t.name, ti.offset, ctx->size);
+                fprintf(stderr, "%s: failed to read tensor data\n", __func__);
+                gguf_free(ctx);
+                return nullptr;
+            }
+            ctx->size += GGML_PAD(ggml_nbytes(&ti.t), ctx->alignment);
+        }
+    }
+
+    // load the tensor data only if requested
+    if (params.ctx != nullptr) {
+        // if the provided gguf_context is no_alloc, then we create "empty" tensors and do not read the binary blob
+        // otherwise, we load the binary blob into the created ggml_context as well, and point the "data" members of
+        //   the ggml_tensor structs to the appropriate locations in the binary blob
+
+        // compute the exact size needed for the new ggml_context
+        const size_t mem_size =
+            params.no_alloc ?
+            (n_tensors    )*ggml_tensor_overhead() :
+            (n_tensors + 1)*ggml_tensor_overhead() + ctx->size;
+
+        struct ggml_init_params pdata = {
+            /*mem_size   =*/ mem_size,
+            /*mem_buffer =*/ nullptr,
+            /*no_alloc   =*/ params.no_alloc,
+        };
+
+        *params.ctx = ggml_init(pdata);
+        if (*params.ctx == nullptr) {
+            fprintf(stderr, "%s: failed to initialize ggml context for storing tensors\n", __func__);
+            gguf_free(ctx);
+            return nullptr;
+        }
+
+        struct ggml_context * ctx_data = *params.ctx;
+
+        struct ggml_tensor * data = nullptr;
+
+        if (!params.no_alloc) {
+            data = ggml_new_tensor_1d(ctx_data, GGML_TYPE_I8, ctx->size);
+
+            ok = ok && data != nullptr;
+
+            if (ok) {
+                ggml_set_name(data, "GGUF tensor data binary blob");
+            }
+
+            // read the binary blob with the tensor data
+            ok = ok && gr.read(data->data, ctx->size);
+
+            if (!ok) {
+                fprintf(stderr, "%s: failed to read tensor data binary blob\n", __func__);
+                ggml_free(ctx_data);
+                *params.ctx = nullptr;
+                gguf_free(ctx);
+                return nullptr;
+            }
+
+            ctx->data = data->data;
+        }
+
+        ggml_set_no_alloc(ctx_data, true);
+
+        // create the tensors
+        for (size_t i = 0; i < ctx->info.size(); ++i) {
+            const struct gguf_tensor_info & info = ctx->info[i];
+
+            struct ggml_tensor * cur = ggml_new_tensor(ctx_data, info.t.type, GGML_MAX_DIMS, info.t.ne);
+
+            ok = ok && cur != nullptr;
+
+            if (!ok) {
+                break;
+            }
+
+            ggml_set_name(cur, info.t.name);
+
+            // point the data member to the appropriate location in the binary blob using the tensor info
+            if (!params.no_alloc) {
+                cur->data = (char *) data->data + info.offset;
+            }
+        }
+
+        if (!ok) {
+            fprintf(stderr, "%s: failed to create tensors\n", __func__);
+            ggml_free(ctx_data);
+            *params.ctx = nullptr;
+            gguf_free(ctx);
+            return nullptr;
+        }
+
+        ggml_set_no_alloc(ctx_data, params.no_alloc);
+    }
+
+    return ctx;
+}
+
+struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) {
+    FILE * file = ggml_fopen(fname, "rb");
+
+    if (!file) {
+        fprintf(stderr, "%s: failed to open GGUF file '%s'\n", __func__, fname);
+        return nullptr;
+    }
+
+    struct gguf_context * result = gguf_init_from_file_impl(file, params);
+    fclose(file);
+    return result;
+}
+
+void gguf_free(struct gguf_context * ctx) {
+    if (ctx == nullptr) {
+        return;
+    }
+    delete ctx;
+}
+
+const char * gguf_type_name(enum gguf_type type) {
+    auto it = GGUF_TYPE_NAME.find(type);
+    return it == GGUF_TYPE_NAME.end() ? nullptr : it->second;
+}
+
+uint32_t gguf_get_version(const struct gguf_context * ctx) {
+    return ctx->version;
+}
+
+size_t gguf_get_alignment(const struct gguf_context * ctx) {
+    return ctx->alignment;
+}
+
+size_t gguf_get_data_offset(const struct gguf_context * ctx) {
+    return ctx->offset;
+}
+
+int64_t gguf_get_n_kv(const struct gguf_context * ctx) {
+    return ctx->kv.size();
+}
+
+int64_t gguf_find_key(const struct gguf_context * ctx, const char * key) {
+    // return -1 if key not found
+    int64_t keyfound = -1;
+
+    const int64_t n_kv = gguf_get_n_kv(ctx);
+
+    for (int64_t i = 0; i < n_kv; ++i) {
+        if (strcmp(key, gguf_get_key(ctx, i)) == 0) {
+            keyfound = i;
+            break;
+        }
+    }
+
+    return keyfound;
+}
+
+const char * gguf_get_key(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    return ctx->kv[key_id].get_key().c_str();
+}
+
+enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    return ctx->kv[key_id].is_array ? GGUF_TYPE_ARRAY : ctx->kv[key_id].get_type();
+}
+
+enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].is_array);
+    return ctx->kv[key_id].get_type();
+}
+
+const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING);
+    return ctx->kv[key_id].data.data();
+}
+
+const char * gguf_get_arr_str(const struct gguf_context * ctx, int64_t key_id, size_t i) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_type() == GGUF_TYPE_STRING);
+    return ctx->kv[key_id].data_string[i].c_str();
+}
+
+size_t gguf_get_arr_n(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+
+    if (ctx->kv[key_id].type == GGUF_TYPE_STRING) {
+        return ctx->kv[key_id].data_string.size();
+    }
+
+    const size_t type_size = gguf_type_size(ctx->kv[key_id].type);
+    GGML_ASSERT(ctx->kv[key_id].data.size() % type_size == 0);
+    return ctx->kv[key_id].data.size() / type_size;
+}
+
+uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val();
+}
+
+int8_t gguf_get_val_i8(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val();
+}
+
+uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val();
+}
+
+int16_t gguf_get_val_i16(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val();
+}
+
+uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val();
+}
+
+int32_t gguf_get_val_i32(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val();
+}
+
+float gguf_get_val_f32(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val();
+}
+
+uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val();
+}
+
+int64_t gguf_get_val_i64(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val();
+}
+
+double gguf_get_val_f64(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val();
+}
+
+bool gguf_get_val_bool(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val();
+}
+
+const char * gguf_get_val_str(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val().c_str();
+}
+
+const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING);
+    return ctx->kv[key_id].data.data();
+}
+
+int64_t gguf_get_n_tensors(const struct gguf_context * ctx) {
+    return ctx->info.size();
+}
+
+int64_t gguf_find_tensor(const struct gguf_context * ctx, const char * name) {
+    // return -1 if tensor not found
+    int64_t tensor_id = -1;
+
+    const int64_t n_tensors = gguf_get_n_tensors(ctx);
+
+    for (int64_t i = 0; i < n_tensors; ++i) {
+        if (strcmp(name, gguf_get_tensor_name(ctx, i)) == 0) {
+            tensor_id = i;
+            break;
+        }
+    }
+
+    return tensor_id;
+}
+
+size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int64_t tensor_id) {
+    GGML_ASSERT(tensor_id >= 0 && tensor_id < gguf_get_n_tensors(ctx));
+    return ctx->info[tensor_id].offset;
+}
+
+const char * gguf_get_tensor_name(const struct gguf_context * ctx, int64_t tensor_id) {
+    GGML_ASSERT(tensor_id >= 0 && tensor_id < gguf_get_n_tensors(ctx));
+    return ctx->info[tensor_id].t.name;
+}
+
+enum ggml_type gguf_get_tensor_type(const struct gguf_context * ctx, int64_t tensor_id) {
+    GGML_ASSERT(tensor_id >= 0 && tensor_id < gguf_get_n_tensors(ctx));
+    return ctx->info[tensor_id].t.type;
+}
+
+size_t gguf_get_tensor_size(const struct gguf_context * ctx, int64_t tensor_id) {
+    GGML_ASSERT(tensor_id >= 0 && tensor_id < gguf_get_n_tensors(ctx));
+    return ggml_nbytes(&ctx->info[tensor_id].t);
+}
+
+int64_t gguf_remove_key(struct gguf_context * ctx, const char * key) {
+    const int64_t key_id = gguf_find_key(ctx, key);
+    if (key_id >= 0) {
+        ctx->kv.erase(ctx->kv.begin() + key_id);
+    }
+    return key_id;
+}
+
+template
+static void gguf_check_reserved_keys(const std::string & key, const T val) {
+    if (key == GGUF_KEY_GENERAL_ALIGNMENT) {
+        if constexpr (std::is_same::value) {
+            GGML_ASSERT(val > 0 && (val & (val - 1)) == 0 && GGUF_KEY_GENERAL_ALIGNMENT " must be power of 2");
+        } else {
+            GGML_ABORT(GGUF_KEY_GENERAL_ALIGNMENT " must be type u32");
+        }
+    }
+}
+
+void gguf_set_val_u8(struct gguf_context * ctx, const char * key, uint8_t val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_i8(struct gguf_context * ctx, const char * key, int8_t val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_u16(struct gguf_context * ctx, const char * key, uint16_t val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_i16(struct gguf_context * ctx, const char * key, int16_t val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_u32(struct gguf_context * ctx, const char * key, uint32_t val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_i32(struct gguf_context * ctx, const char * key, int32_t val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_f32(struct gguf_context * ctx, const char * key, float val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_u64(struct gguf_context * ctx, const char * key, uint64_t val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_i64(struct gguf_context * ctx, const char * key, int64_t val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_f64(struct gguf_context * ctx, const char * key, double val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_str(struct gguf_context * ctx, const char * key, const char * val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, std::string(val));
+}
+
+void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, size_t n) {
+    gguf_check_reserved_keys(key, data);
+    gguf_remove_key(ctx, key);
+
+    const size_t nbytes = n*gguf_type_size(type);
+    std::vector tmp(nbytes);
+    if (!tmp.empty()) {
+        memcpy(tmp.data(), data, nbytes);
+    }
+    ctx->kv.emplace_back(key, tmp);
+    ctx->kv.back().cast(type);
+}
+
+void gguf_set_arr_str(struct gguf_context * ctx, const char * key, const char ** data, size_t n) {
+    gguf_check_reserved_keys(key, data);
+    gguf_remove_key(ctx, key);
+
+    std::vector tmp(n);
+    for (size_t i = 0; i < n; ++i) {
+        tmp[i] = data[i];
+    }
+    ctx->kv.emplace_back(key, tmp);
+}
+
+// set or add KV pairs from another context
+void gguf_set_kv(struct gguf_context * ctx, const struct gguf_context * src) {
+    const int64_t n_kv = gguf_get_n_kv(src);
+    for (int64_t i = 0; i < n_kv; ++i) {
+        const struct gguf_kv & kv = src->kv[i];
+
+        if (!kv.is_array) {
+            switch (kv.get_type()) {
+                case GGUF_TYPE_UINT8:   gguf_set_val_u8  (ctx, kv.get_key().c_str(), kv.get_val());             break;
+                case GGUF_TYPE_INT8:    gguf_set_val_i8  (ctx, kv.get_key().c_str(), kv.get_val());              break;
+                case GGUF_TYPE_UINT16:  gguf_set_val_u16 (ctx, kv.get_key().c_str(), kv.get_val());            break;
+                case GGUF_TYPE_INT16:   gguf_set_val_i16 (ctx, kv.get_key().c_str(), kv.get_val());             break;
+                case GGUF_TYPE_UINT32:  gguf_set_val_u32 (ctx, kv.get_key().c_str(), kv.get_val());            break;
+                case GGUF_TYPE_INT32:   gguf_set_val_i32 (ctx, kv.get_key().c_str(), kv.get_val());             break;
+                case GGUF_TYPE_FLOAT32: gguf_set_val_f32 (ctx, kv.get_key().c_str(), kv.get_val());               break;
+                case GGUF_TYPE_UINT64:  gguf_set_val_u64 (ctx, kv.get_key().c_str(), kv.get_val());            break;
+                case GGUF_TYPE_INT64:   gguf_set_val_i64 (ctx, kv.get_key().c_str(), kv.get_val());             break;
+                case GGUF_TYPE_FLOAT64: gguf_set_val_f64 (ctx, kv.get_key().c_str(), kv.get_val());              break;
+                case GGUF_TYPE_BOOL:    gguf_set_val_bool(ctx, kv.get_key().c_str(), kv.get_val());                break;
+                case GGUF_TYPE_STRING:  gguf_set_val_str (ctx, kv.get_key().c_str(), kv.get_val().c_str()); break;
+                case GGUF_TYPE_ARRAY:
+                default: GGML_ABORT("invalid type");
+            }
+            continue;
+        }
+
+        const size_t ne = kv.get_ne();
+
+        switch (kv.get_type()) {
+            case GGUF_TYPE_UINT8:
+            case GGUF_TYPE_INT8:
+            case GGUF_TYPE_UINT16:
+            case GGUF_TYPE_INT16:
+            case GGUF_TYPE_UINT32:
+            case GGUF_TYPE_INT32:
+            case GGUF_TYPE_FLOAT32:
+            case GGUF_TYPE_UINT64:
+            case GGUF_TYPE_INT64:
+            case GGUF_TYPE_FLOAT64:
+            case GGUF_TYPE_BOOL: {
+                gguf_set_arr_data(ctx, kv.get_key().c_str(), kv.get_type(), kv.data.data(), ne);
+            } break;
+            case GGUF_TYPE_STRING: {
+                std::vector tmp(ne);
+                for (size_t j = 0; j < ne; ++j) {
+                    tmp[j] = kv.data_string[j].c_str();
+                }
+                gguf_set_arr_str(ctx, kv.get_key().c_str(), tmp.data(), ne);
+            } break;
+            case GGUF_TYPE_ARRAY:
+            default: GGML_ABORT("invalid type");
+        }
+    }
+}
+
+void gguf_add_tensor(
+             struct gguf_context * ctx,
+        const struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor);
+    if (gguf_find_tensor(ctx, tensor->name) != -1) {
+        GGML_ABORT("duplicate tensor name: %s", tensor->name);
+    }
+
+    struct gguf_tensor_info ti;
+    ti.t = *tensor;
+    ti.offset = ctx->info.empty() ? 0 :
+        ctx->info.back().offset + GGML_PAD(ggml_nbytes(&ctx->info.back().t), ctx->alignment);
+    ctx->info.push_back(ti);
+}
+
+void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type) {
+    const int64_t tensor_id = gguf_find_tensor(ctx, name);
+    if (tensor_id < 0) {
+        GGML_ABORT("tensor not found: %s", name);
+    }
+    struct ggml_tensor * tensor = &ctx->info[tensor_id].t;
+    const size_t  type_size = ggml_type_size(type);
+    const int64_t blck_size = ggml_blck_size(type);
+
+    tensor->type = type;
+    GGML_ASSERT(tensor->ne[0] % blck_size == 0 && "tensor row size not divisible by block size of new type");
+
+    tensor->nb[0] = type_size;
+    tensor->nb[1] = tensor->nb[0]*(tensor->ne[0]/blck_size);
+    for (int i = 2; i < GGML_MAX_DIMS; i++) {
+        tensor->nb[i] = tensor->nb[i - 1]*tensor->ne[i - 1];
+    }
+
+    // update offsets
+    const int64_t n_tensors = gguf_get_n_tensors(ctx);
+    for (int64_t i = tensor_id + 1; i < n_tensors; ++i) {
+        ctx->info[i].offset = ctx->info[i - 1].offset + GGML_PAD(ggml_nbytes(&ctx->info[i - 1].t), ctx->alignment);
+    }
+}
+
+void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data) {
+    const int64_t tensor_id = gguf_find_tensor(ctx, name);
+    if (tensor_id < 0) {
+        GGML_ABORT("tensor not found: %s", name);
+    }
+
+    ctx->info[tensor_id].t.data = (void *)(uintptr_t)data; // double cast suppresses warning about casting away const
+}
+
+struct gguf_writer {
+    std::vector & buf;
+
+    gguf_writer(std::vector & buf) : buf(buf) {}
+
+    template 
+    void write(const T & val) const {
+        for (size_t i = 0; i < sizeof(val); ++i) {
+            buf.push_back(reinterpret_cast(&val)[i]);
+        }
+    }
+
+    void write(const std::vector & val) const {
+        buf.insert(buf.end(), val.begin(), val.end());
+    }
+
+    void write(const bool & val) const {
+        const int8_t val8 = val ? 1 : 0;
+        write(val8);
+    }
+
+    void write(const std::string & val) const {
+        {
+            const uint64_t n = val.length();
+            write(n);
+        }
+        for (size_t i = 0; i < val.length(); ++i) {
+            buf.push_back(reinterpret_cast(val.data())[i]);
+        }
+    }
+
+    void write(const char * val) const {
+        write(std::string(val));
+    }
+
+    void write(const enum ggml_type & val) const {
+        write(int32_t(val));
+    }
+
+    void write(const enum gguf_type & val) const {
+        write(int32_t(val));
+    }
+
+    void write(const struct gguf_kv & kv) const {
+        const uint64_t ne = kv.get_ne();
+
+        write(kv.get_key());
+
+        if (kv.is_array) {
+            write(GGUF_TYPE_ARRAY);
+            write(kv.get_type());
+            write(ne);
+        } else {
+            write(kv.get_type());
+        }
+
+        switch (kv.get_type()) {
+            case GGUF_TYPE_UINT8:
+            case GGUF_TYPE_INT8:
+            case GGUF_TYPE_UINT16:
+            case GGUF_TYPE_INT16:
+            case GGUF_TYPE_UINT32:
+            case GGUF_TYPE_INT32:
+            case GGUF_TYPE_FLOAT32:
+            case GGUF_TYPE_UINT64:
+            case GGUF_TYPE_INT64:
+            case GGUF_TYPE_FLOAT64: {
+                write(kv.data);
+            } break;
+            case GGUF_TYPE_BOOL: {
+                for (size_t i = 0; i < ne; ++i) {
+                    write(kv.get_val(i));
+                }
+            } break;
+            case GGUF_TYPE_STRING: {
+                for (size_t i = 0; i < ne; ++i) {
+                    write(kv.get_val(i));
+                }
+            } break;
+            case GGUF_TYPE_ARRAY:
+            default: GGML_ABORT("invalid type");
+        }
+    }
+
+    void write_tensor_meta(const struct gguf_tensor_info & info) const {
+        write(info.t.name);
+
+        const uint32_t n_dims = ggml_n_dims(&info.t);
+        write(n_dims);
+
+        for (uint32_t j = 0; j < n_dims; ++j) {
+            write(info.t.ne[j]);
+        }
+        write(info.t.type);
+        write(info.offset);
+    }
+
+    void pad(const size_t alignment) const {
+        while (buf.size() % alignment != 0) {
+            const int8_t zero = 0;
+            write(zero);
+        }
+    }
+
+    void write_tensor_data(const struct gguf_tensor_info & info, const size_t offset_data, const size_t alignment) const {
+        GGML_ASSERT(buf.size() - offset_data == info.offset);
+
+        GGML_ASSERT(ggml_is_contiguous(&info.t));
+        const size_t offset = buf.size();
+        const size_t nbytes = ggml_nbytes(&info.t);
+
+        buf.resize(offset + nbytes);
+        if (info.t.buffer) {
+            ggml_backend_tensor_get(&info.t, buf.data() + offset, 0, nbytes);
+        } else {
+            GGML_ASSERT(info.t.data);
+            memcpy(buf.data() + offset, info.t.data, nbytes);
+        }
+
+        pad(alignment);
+    }
+};
+
+void gguf_write_to_buf(const struct gguf_context * ctx, std::vector & buf, bool only_meta) {
+    const struct gguf_writer gw(buf);
+
+    const int64_t n_kv      = gguf_get_n_kv(ctx);
+    const int64_t n_tensors = gguf_get_n_tensors(ctx);
+
+    // write header
+    gw.write(GGUF_MAGIC[0]);
+    gw.write(GGUF_MAGIC[1]);
+    gw.write(GGUF_MAGIC[2]);
+    gw.write(GGUF_MAGIC[3]);
+    gw.write(ctx->version);
+    gw.write(n_tensors);
+    gw.write(n_kv);
+
+    // write key-value pairs
+    for (int64_t i = 0; i < n_kv; ++i) {
+        gw.write(ctx->kv[i]);
+    }
+
+    // write tensor info
+    for (int64_t i = 0; i < n_tensors; ++i) {
+        gw.write_tensor_meta(ctx->info[i]);
+    }
+
+    // we require the data section to be aligned
+    gw.pad(ctx->alignment);
+
+    if (only_meta) {
+        return;
+    }
+
+    const size_t offset_data = gw.buf.size();
+
+    // write tensor data
+    for (int64_t i = 0; i < n_tensors; ++i) {
+        gw.write_tensor_data(ctx->info[i], offset_data, ctx->alignment);
+    }
+}
+
+bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) {
+    FILE * file = ggml_fopen(fname, "wb");
+
+    if (!file) {
+        fprintf(stderr, "%s: failed to open file '%s' for writing GGUF data\n", __func__, fname);
+        return false;
+    }
+
+    std::vector buf;
+    gguf_write_to_buf(ctx, buf, only_meta);
+    const bool ok = fwrite(buf.data(), 1, buf.size(), file) == buf.size();
+    fclose(file);
+    return ok;
+}
+
+size_t gguf_get_meta_size(const struct gguf_context * ctx) {
+    // only return size
+    std::vector buf;
+    gguf_write_to_buf(ctx, buf, /*only_meta =*/ true);
+    return buf.size();
+}
+
+void gguf_get_meta_data(const struct gguf_context * ctx, void * data) {
+    std::vector buf;
+    gguf_write_to_buf(ctx, buf, /*only_meta =*/ true);
+    memcpy(data, buf.data(), buf.size());
+}

From a5272130c4da24615a7428f5e7982a17ba64a6d8 Mon Sep 17 00:00:00 2001
From: Jeffrey Morgan 
Date: Wed, 26 Feb 2025 22:33:53 -0800
Subject: [PATCH 04/20] ml/backend/ggml: follow on fixes after updating
 vendored code (#9388)

Fixes sync filters and lowers CUDA version to 11.3 in test.yaml
---
 .github/workflows/test.yaml                   |  6 +--
 ml/backend/ggml/ggml/.rsync-filter            |  1 +
 .../ggml/ggml/src/ggml-cuda/vendors/cuda.h    |  1 +
 .../ggml/ggml/src/ggml-cuda/vendors/hip.h     | 46 +++++++++++++++++++
 4 files changed, 51 insertions(+), 3 deletions(-)

diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
index bb0e8d90..479a9bb8 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/test.yaml
@@ -78,8 +78,8 @@ jobs:
         include:
           - preset: CPU
           - preset: CUDA
-            install: https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_522.06_windows.exe
-            flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
+            install: https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe
+            flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
           - preset: ROCm
             install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
             flags: '-DAMDGPU_TARGETS=gfx1010'
@@ -102,7 +102,7 @@ jobs:
           $ErrorActionPreference = "Stop"
           if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
             Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
-            Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_11.8", "nvcc_11.8", "cublas_11.8", "cublas_dev_11.8")) -NoNewWindow -Wait
+            Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_11.3", "nvcc_11.3", "cublas_11.3", "cublas_dev_11.3")) -NoNewWindow -Wait
           }
 
           $cudaPath = (Resolve-Path "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*").path
diff --git a/ml/backend/ggml/ggml/.rsync-filter b/ml/backend/ggml/ggml/.rsync-filter
index c5acbe49..ddad16e2 100644
--- a/ml/backend/ggml/ggml/.rsync-filter
+++ b/ml/backend/ggml/ggml/.rsync-filter
@@ -9,6 +9,7 @@ include src/ggml-cpu/
 include src/ggml-cpu/amx/
 include src/ggml-cpu/llamafile/
 include src/ggml-cuda/
+include src/ggml-cuda/vendors/
 include src/ggml-cuda/template-instances/
 include src/ggml-hip/
 include src/ggml-metal/
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/vendors/cuda.h b/ml/backend/ggml/ggml/src/ggml-cuda/vendors/cuda.h
index db9f6a16..1746b073 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/vendors/cuda.h
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/vendors/cuda.h
@@ -3,6 +3,7 @@
 #include 
 #include 
 #include 
+#include 
 #include 
 
 #if CUDART_VERSION < 11020
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h b/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h
index c905b15d..81964611 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h
@@ -1,5 +1,6 @@
 #pragma once
 
+#define HIP_ENABLE_WARP_SYNC_BUILTINS 1
 #include 
 #include 
 #include 
@@ -8,6 +9,7 @@
 // for rocblas_initialize()
 #include "rocblas/rocblas.h"
 #endif // __HIP_PLATFORM_AMD__
+
 #define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
 #define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
 #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
@@ -19,6 +21,13 @@
 #define CUBLAS_TF32_TENSOR_OP_MATH 0
 #define CUDA_R_16F  HIPBLAS_R_16F
 #define CUDA_R_32F  HIPBLAS_R_32F
+#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported
+#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended
+#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned
+#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
+#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
+#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
+#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
 #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
 #define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
 #define cublasCreate hipblasCreate
@@ -74,6 +83,21 @@
 #define cudaMemGetInfo hipMemGetInfo
 #define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
 #define cudaSetDevice hipSetDevice
+#define cuDeviceGet hipDeviceGet
+#define CUdevice hipDevice_t
+#define CUdeviceptr hipDeviceptr_t
+#define cuMemUnmap hipMemUnmap
+#define CUmemAccessDesc hipMemAccessDesc
+#define cuMemAddressFree hipMemAddressFree
+#define cuMemRelease hipMemRelease
+#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t
+#define cuMemCreate hipMemCreate
+#define cuMemAddressReserve hipMemAddressReserve
+#define cuMemMap hipMemMap
+#define cuMemSetAccess hipMemSetAccess
+#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity
+#define CUmemAllocationProp hipMemAllocationProp
+#define cuDeviceGetAttribute hipDeviceGetAttribute
 #define cudaStreamCreateWithFlags hipStreamCreateWithFlags
 #define cudaStreamDestroy hipStreamDestroy
 #define cudaStreamFireAndForget hipStreamFireAndForget
@@ -81,6 +105,28 @@
 #define cudaStreamPerThread hipStreamPerThread
 #define cudaStreamSynchronize hipStreamSynchronize
 #define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
+#define cudaGraphExec_t hipGraphExec_t
+#define cudaGraphNode_t hipGraphNode_t
+#define cudaKernelNodeParams hipKernelNodeParams
+#define cudaKernelNodeParams hipKernelNodeParams
+#define cudaGraphExecDestroy hipGraphExecDestroy
+#define cudaGraphLaunch hipGraphLaunch
+#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
+#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult
+#define cudaGraphNodeType hipGraphNodeType
+#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
+#define cudaGraphInstantiate hipGraphInstantiate
+#define cudaStreamEndCapture hipStreamEndCapture
+#define cudaGraphDestroy hipGraphDestroy
+#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams
+#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction
+#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams
+#define cudaGraphNodeGetType hipGraphNodeGetType
+#define cudaGraphGetNodes hipGraphGetNodes
+#define cudaGraphExecUpdate hipGraphExecUpdate
+#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed
+#define cudaStreamBeginCapture hipStreamBeginCapture
+#define cudaGraph_t hipGraph_t
 #define cudaStream_t hipStream_t
 #define cudaSuccess hipSuccess
 #define __trap() do { abort(); __builtin_unreachable(); } while(0)

From 76e903cf9db33f3f3965596896f47c1186496c4d Mon Sep 17 00:00:00 2001
From: Blake Mizerany 
Date: Wed, 26 Feb 2025 23:03:48 -0800
Subject: [PATCH 05/20] .github/workflows: swap order of go test and
 golangci-lint (#9389)

The linter is secondary to the tests, so it should run after the tests,
exposing test failures faster.
---
 .github/workflows/test.yaml | 29 ++++++++++++++---------------
 1 file changed, 14 insertions(+), 15 deletions(-)

diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
index 479a9bb8..e7e47c96 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/test.yaml
@@ -190,6 +190,20 @@ jobs:
 
           go-version-file: go.mod
 
+      # It is tempting to run this in a platform independent way, but the past
+      # shows this codebase will see introductions of platform specific code
+      # generation, and so we need to check this per platform to ensure we
+      # don't abuse go generate on specific platforms.
+      - name: check that 'go generate' is clean
+        if: always()
+        run: |
+          go generate ./...
+          git diff --name-only --exit-code || (echo "Please run 'go generate ./...'." && exit 1)
+
+      - name: go test
+        if: always()
+        run: go test -count=1 -benchtime=1x ./...
+
       # TODO(bmizerany): replace this heavy tool with just the
       # tools/checks/binaries we want and then make them all run in parallel
       # across jobs, not on a single tiny vm on Github Actions.
@@ -197,21 +211,6 @@ jobs:
         with:
           args: --timeout 10m0s -v
 
-      - name: go test
-        # Do not skip tests in the face of linter errors, or 'go mod tidy'
-        # checks, which are secondary to the tests. Tests trump linters.
-        if: always()
-        run: go test -count=1 -benchtime=1x ./...
-
-      # It is tempting to run this in a platform independent way, but the past
-      # shows this codebase will see introductions of platform specific code
-      # generation, and so we need to check this per platform to ensure we
-      # don't abuse go generate on specific platforms.
-      - name: check that 'go generate' is clean
-        run: |
-          go generate ./...
-          git diff --name-only --exit-code || (echo "Please run 'go generate ./...'." && exit 1)
-
       - name: cache save
         # Always save the cache, even if the job fails. The artifacts produced
         # during the building of test binaries are not all for naught. They can

From 688925aca9a4747098ee331c3ecb702907fbca23 Mon Sep 17 00:00:00 2001
From: Daniel Hiltgen 
Date: Thu, 27 Feb 2025 09:02:25 -0800
Subject: [PATCH 06/20] Windows ARM build (#9120)

* Windows ARM build

Skip cmake, and note it's unused in the developer docs.

* Win: only check for ninja when we need it

On windows ARM, the cim lookup fails, but we don't need ninja anyway.
---
 docs/development.md       |  2 +-
 scripts/build_windows.ps1 | 12 ++++--------
 2 files changed, 5 insertions(+), 9 deletions(-)

diff --git a/docs/development.md b/docs/development.md
index 6e68c9eb..eb67dbfa 100644
--- a/docs/development.md
+++ b/docs/development.md
@@ -69,7 +69,7 @@ go run . serve
 
 ## Windows (ARM)
 
-Windows ARM does not support additional acceleration libraries at this time.
+Windows ARM does not support additional acceleration libraries at this time.  Do not use cmake, simply `go run` or `go build`.
 
 ## Linux
 
diff --git a/scripts/build_windows.ps1 b/scripts/build_windows.ps1
index 312c3db5..62930d7f 100644
--- a/scripts/build_windows.ps1
+++ b/scripts/build_windows.ps1
@@ -26,9 +26,6 @@ function checkEnv() {
         $MSVC_INSTALL=(Get-CimInstance MSFT_VSInstance -Namespace root/cimv2/vs)[0].InstallLocation
         $env:VCToolsRedistDir=(get-item "${MSVC_INSTALL}\VC\Redist\MSVC\*")[0]
     }
-    if (-Not (get-command -ErrorAction silent ninja)) {
-        $script:NINJA_DIR=(gci -path (Get-CimInstance MSFT_VSInstance -Namespace root/cimv2/vs)[0].InstallLocation -r -fi ninja.exe) | split-path -parent
-    }
     # Locate CUDA versions
     # Note: this assumes every version found will be built
     $cudaList=(get-item "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v*\bin\" -ea 'silentlycontinue')
@@ -83,7 +80,7 @@ function checkEnv() {
 
 
 function buildOllama() {
-    if ($null -eq ${env:OLLAMA_SKIP_GENERATE}) {
+    if ($script:ARCH -ne "arm64") {
         Remove-Item -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}"
         New-Item "${script:SRC_DIR}\dist\windows-${script:ARCH}\lib\ollama\" -ItemType Directory -ea 0
 
@@ -122,8 +119,9 @@ function buildOllama() {
         }
         if ($env:HIP_PATH) {
             write-host "Building ROCm backend libraries"
-            if ($null -ne $script:NINJA_DIR) {
-                $env:PATH="$script:NINJA_DIR;$env:PATH"
+            if (-Not (get-command -ErrorAction silent ninja)) {
+                $NINJA_DIR=(gci -path (Get-CimInstance MSFT_VSInstance -Namespace root/cimv2/vs)[0].InstallLocation -r -fi ninja.exe) | split-path -parent
+                $env:PATH="$NINJA_DIR;$env:PATH"
             }
             $env:HIPCXX="${env:HIP_PATH}\bin\clang++.exe"
             $env:HIP_PLATFORM="amd"
@@ -138,8 +136,6 @@ function buildOllama() {
             & cmake --install build --component "HIP" --strip
             if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
         }
-    } else {
-        write-host "Skipping generate step with OLLAMA_SKIP_GENERATE set"
     }
     write-host "Building ollama CLI"
     & go build -trimpath -ldflags "-s -w -X=github.com/ollama/ollama/version.Version=$script:VERSION -X=github.com/ollama/ollama/server.mode=release" .

From a59f66523561dc195a37b78b454d8cd1b8b1fdd7 Mon Sep 17 00:00:00 2001
From: Michael Yang 
Date: Wed, 26 Feb 2025 17:00:25 -0800
Subject: [PATCH 07/20] ml/backend/ggml: fix debug logging

---
 llama/llama.go                   | 37 +++++++++++++-------------------
 ml/backend/ggml/ggml/src/ggml.go | 17 ++++++---------
 runner/llamarunner/runner.go     |  1 -
 3 files changed, 21 insertions(+), 34 deletions(-)

diff --git a/llama/llama.go b/llama/llama.go
index 6eed3d47..9add38c2 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -37,23 +37,36 @@ COMPILER inline get_compiler() {
 import "C"
 
 import (
+	"context"
 	_ "embed"
 	"errors"
 	"fmt"
+	"log/slog"
 	"os"
 	"runtime"
 	"runtime/cgo"
 	"slices"
 	"strings"
-	"sync/atomic"
 	"unsafe"
 
 	_ "github.com/ollama/ollama/llama/llama.cpp/common"
 	_ "github.com/ollama/ollama/llama/llama.cpp/examples/llava"
 	_ "github.com/ollama/ollama/llama/llama.cpp/src"
-	"github.com/ollama/ollama/ml/backend/ggml/ggml/src"
+	ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
 )
 
+func init() {
+	C.llama_log_set(C.ggml_log_callback(C.llamaLog), nil)
+}
+
+//export llamaLog
+func llamaLog(level C.int, text *C.char, _ unsafe.Pointer) {
+	// slog levels zeros INFO and are multiples of 4
+	if slog.Default().Enabled(context.TODO(), slog.Level(int(level-C.GGML_LOG_LEVEL_INFO)*4)) {
+		fmt.Fprint(os.Stderr, C.GoString(text))
+	}
+}
+
 func BackendInit() {
 	ggml.OnceLoad()
 	C.llama_backend_init()
@@ -72,26 +85,6 @@ func PrintSystemInfo() string {
 	return C.GoString(C.llama_print_system_info()) + compiler
 }
 
-var logLevel atomic.Int32
-
-func init() {
-	logLevel.Store(int32(C.GGML_LOG_LEVEL_INFO))
-	C.llama_log_set((C.ggml_log_callback)(C.llamaLog), nil)
-}
-
-func EnableDebug() {
-	logLevel.Store(int32(C.GGML_LOG_LEVEL_DEBUG))
-}
-
-//export llamaLog
-func llamaLog(level int32, text *C.char, _ unsafe.Pointer) {
-	if level < logLevel.Load() {
-		return
-	}
-
-	fmt.Fprint(os.Stderr, C.GoString(text))
-}
-
 func GetModelArch(modelPath string) (string, error) {
 	mp := C.CString(modelPath)
 	defer C.free(unsafe.Pointer(mp))
diff --git a/ml/backend/ggml/ggml/src/ggml.go b/ml/backend/ggml/ggml/src/ggml.go
index 3920e37d..85c693eb 100644
--- a/ml/backend/ggml/ggml/src/ggml.go
+++ b/ml/backend/ggml/ggml/src/ggml.go
@@ -10,6 +10,8 @@ package ggml
 import "C"
 
 import (
+	"context"
+	"fmt"
 	"log/slog"
 	"os"
 	"path/filepath"
@@ -22,21 +24,14 @@ import (
 )
 
 func init() {
-	C.ggml_log_set((C.ggml_log_callback)(C.sink), nil)
+	C.ggml_log_set(C.ggml_log_callback(C.sink), nil)
 }
 
 //export sink
 func sink(level C.int, text *C.char, _ unsafe.Pointer) {
-	msg := strings.TrimSpace(C.GoString(text))
-	switch level {
-	case C.GGML_LOG_LEVEL_DEBUG:
-		slog.Debug(msg)
-	case C.GGML_LOG_LEVEL_INFO:
-		slog.Info(msg)
-	case C.GGML_LOG_LEVEL_WARN:
-		slog.Warn(msg)
-	case C.GGML_LOG_LEVEL_ERROR:
-		slog.Error(msg)
+	// slog levels zeros INFO and are multiples of 4
+	if slog.Default().Enabled(context.TODO(), slog.Level(int(level-C.GGML_LOG_LEVEL_INFO)*4)) {
+		fmt.Fprint(os.Stderr, C.GoString(text))
 	}
 }
 
diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go
index 72873ec4..f9d20401 100644
--- a/runner/llamarunner/runner.go
+++ b/runner/llamarunner/runner.go
@@ -915,7 +915,6 @@ func Execute(args []string) error {
 	level := slog.LevelInfo
 	if *verbose {
 		level = slog.LevelDebug
-		llama.EnableDebug()
 	}
 	handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
 		Level:     level,

From d6af13efedb5c211589af0c2a6ad6fe1491bf302 Mon Sep 17 00:00:00 2001
From: Michael Yang 
Date: Wed, 26 Feb 2025 15:16:22 -0800
Subject: [PATCH 08/20] runner: simplify tensor split parsing

---
 runner/llamarunner/runner.go  | 9 ++++-----
 runner/ollamarunner/runner.go | 9 ++++-----
 2 files changed, 8 insertions(+), 10 deletions(-)

diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go
index f9d20401..1afc793e 100644
--- a/runner/llamarunner/runner.go
+++ b/runner/llamarunner/runner.go
@@ -943,12 +943,11 @@ func Execute(args []string) error {
 
 	var tensorSplitFloats []float32
 	if *tensorSplit != "" {
-		stringFloats := regexp.MustCompile(",").Split(*tensorSplit, -1)
-
-		tensorSplitFloats = make([]float32, 0, len(stringFloats))
-		for _, s := range stringFloats {
+		splits := strings.Split(*tensorSplit, ",")
+		tensorSplitFloats = make([]float32, len(splits))
+		for i, s := range splits {
 			f, _ := strconv.ParseFloat(s, 32)
-			tensorSplitFloats = append(tensorSplitFloats, float32(f))
+			tensorSplitFloats[i] = float32(f)
 		}
 	}
 
diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index d3998120..b39d747f 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -881,12 +881,11 @@ func Execute(args []string) error {
 
 	var tensorSplitFloats []float32
 	if *tensorSplit != "" {
-		stringFloats := regexp.MustCompile(",").Split(*tensorSplit, -1)
-
-		tensorSplitFloats = make([]float32, 0, len(stringFloats))
-		for _, s := range stringFloats {
+		splits := strings.Split(*tensorSplit, ",")
+		tensorSplitFloats = make([]float32, len(splits))
+		for i, s := range splits {
 			f, _ := strconv.ParseFloat(s, 32)
-			tensorSplitFloats = append(tensorSplitFloats, float32(f))
+			tensorSplitFloats[i] = float32(f)
 		}
 	}
 

From dc13813a03105bd76603a4909e31ba0c034d670d Mon Sep 17 00:00:00 2001
From: Eries Trisnadi 
Date: Fri, 28 Feb 2025 01:39:43 +0700
Subject: [PATCH 09/20] server: allow vscode-file origins (#9313)

---
 envconfig/config.go      | 1 +
 envconfig/config_test.go | 4 ++++
 2 files changed, 5 insertions(+)

diff --git a/envconfig/config.go b/envconfig/config.go
index 6117aa26..fc702198 100644
--- a/envconfig/config.go
+++ b/envconfig/config.go
@@ -73,6 +73,7 @@ func AllowedOrigins() (origins []string) {
 		"file://*",
 		"tauri://*",
 		"vscode-webview://*",
+		"vscode-file://*",
 	)
 
 	return origins
diff --git a/envconfig/config_test.go b/envconfig/config_test.go
index 385dab5f..5694eb8a 100644
--- a/envconfig/config_test.go
+++ b/envconfig/config_test.go
@@ -69,6 +69,7 @@ func TestOrigins(t *testing.T) {
 			"file://*",
 			"tauri://*",
 			"vscode-webview://*",
+			"vscode-file://*",
 		}},
 		{"http://10.0.0.1", []string{
 			"http://10.0.0.1",
@@ -88,6 +89,7 @@ func TestOrigins(t *testing.T) {
 			"file://*",
 			"tauri://*",
 			"vscode-webview://*",
+			"vscode-file://*",
 		}},
 		{"http://172.16.0.1,https://192.168.0.1", []string{
 			"http://172.16.0.1",
@@ -108,6 +110,7 @@ func TestOrigins(t *testing.T) {
 			"file://*",
 			"tauri://*",
 			"vscode-webview://*",
+			"vscode-file://*",
 		}},
 		{"http://totally.safe,http://definitely.legit", []string{
 			"http://totally.safe",
@@ -128,6 +131,7 @@ func TestOrigins(t *testing.T) {
 			"file://*",
 			"tauri://*",
 			"vscode-webview://*",
+			"vscode-file://*",
 		}},
 	}
 	for _, tt := range cases {

From be2ac1ed93db2fea3bf989da4628202414db7a96 Mon Sep 17 00:00:00 2001
From: Steven Hartland 
Date: Thu, 27 Feb 2025 18:51:12 +0000
Subject: [PATCH 10/20] docs: fix api examples link (#9360)

Fix the examples link in the go package documentation for the API.
---
 api/client.go | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/api/client.go b/api/client.go
index f87ea0fd..3dffce60 100644
--- a/api/client.go
+++ b/api/client.go
@@ -10,7 +10,7 @@
 // repository].
 //
 // [the API documentation]: https://github.com/ollama/ollama/blob/main/docs/api.md
-// [in the GitHub repository]: https://github.com/ollama/ollama/tree/main/examples
+// [in the GitHub repository]: https://github.com/ollama/ollama/tree/main/api/examples
 package api
 
 import (

From 2412adf42b8380748ac79476e273f5b337c3b977 Mon Sep 17 00:00:00 2001
From: Blake Mizerany 
Date: Thu, 27 Feb 2025 12:04:53 -0800
Subject: [PATCH 11/20] server/internal: replace model delete API with new
 registry handler. (#9347)

This commit introduces a new API implementation for handling
interactions with the registry and the local model cache. The new API is
located in server/internal/registry. The package name is "registry" and
should be considered temporary; it is hidden and not bleeding outside of
the server package. As the commits roll in, we'll start consuming more
of the API and then let reverse osmosis take effect, at which point it
will surface closer to the root level packages as much as needed.
---
 go.mod                                        |  14 +-
 go.sum                                        |  28 +--
 server/internal/cache/blob/cache.go           |  28 ++-
 server/internal/cache/blob/cache_test.go      |   7 +-
 server/internal/cache/blob/casecheck_test.go  |   2 +-
 server/internal/client/ollama/registry.go     |  94 ++++++--
 .../internal/client/ollama/registry_test.go   | 120 ++++++++--
 server/internal/cmd/opp/opp.go                |   4 +-
 server/internal/registry/server.go            | 215 ++++++++++++++++++
 server/internal/registry/server_test.go       | 168 ++++++++++++++
 ...28784d7106b60a4eb1d246e950a001a3f944fbda99 | Bin 0 -> 24 bytes
 ...cdc3840217bd502124a9d3687d438c19b3cb9c3cb1 |   1 +
 .../registry.ollama.ai/library/smol/latest    |   1 +
 .../{internal => }/testutil/testutil.go       |  28 +++
 server/routes.go                              |  39 +++-
 server/routes_test.go                         |  46 +++-
 16 files changed, 705 insertions(+), 90 deletions(-)
 create mode 100644 server/internal/registry/server.go
 create mode 100644 server/internal/registry/server_test.go
 create mode 100644 server/internal/registry/testdata/models/blobs/sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99
 create mode 100644 server/internal/registry/testdata/models/blobs/sha256-ecfb1acfca9c76444d622fcdc3840217bd502124a9d3687d438c19b3cb9c3cb1
 create mode 100644 server/internal/registry/testdata/models/manifests/registry.ollama.ai/library/smol/latest
 rename server/internal/{internal => }/testutil/testutil.go (72%)

diff --git a/go.mod b/go.mod
index a6107a62..5f08aad0 100644
--- a/go.mod
+++ b/go.mod
@@ -11,7 +11,7 @@ require (
 	github.com/spf13/cobra v1.7.0
 	github.com/stretchr/testify v1.9.0
 	github.com/x448/float16 v0.8.4
-	golang.org/x/sync v0.10.0
+	golang.org/x/sync v0.11.0
 )
 
 require (
@@ -69,12 +69,12 @@ require (
 	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
 	github.com/ugorji/go/codec v1.2.12 // indirect
 	golang.org/x/arch v0.8.0 // indirect
-	golang.org/x/crypto v0.31.0
-	golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa
-	golang.org/x/net v0.25.0 // indirect
-	golang.org/x/sys v0.28.0
-	golang.org/x/term v0.27.0
-	golang.org/x/text v0.21.0
+	golang.org/x/crypto v0.33.0
+	golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa
+	golang.org/x/net v0.35.0 // indirect
+	golang.org/x/sys v0.30.0
+	golang.org/x/term v0.29.0
+	golang.org/x/text v0.22.0
 	google.golang.org/protobuf v1.34.1
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 )
diff --git a/go.sum b/go.sum
index 8eb8d84a..013a7db7 100644
--- a/go.sum
+++ b/go.sum
@@ -214,16 +214,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
 golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
-golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
-golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
+golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
+golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
 golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE=
-golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa h1:FRnLl4eNAQl8hwxVVC17teOw8kdjVDVAiFMtgUdTSRQ=
-golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE=
+golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa h1:t2QcU6V556bFjYgu4L6C+6VrCPyJZ+eyRsABUPs1mz4=
+golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk=
 golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
 golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
 golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
@@ -257,8 +257,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
 golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
 golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
 golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
-golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
-golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
+golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
+golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
 golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
 golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
 golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -268,8 +268,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
 golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
-golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
+golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
+golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
 golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -285,17 +285,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
 golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
-golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
+golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
-golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q=
-golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
+golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU=
+golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
-golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
+golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
+golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
 golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
diff --git a/server/internal/cache/blob/cache.go b/server/internal/cache/blob/cache.go
index f0b0760f..8a828772 100644
--- a/server/internal/cache/blob/cache.go
+++ b/server/internal/cache/blob/cache.go
@@ -279,6 +279,18 @@ func (c *DiskCache) Get(d Digest) (Entry, error) {
 // It returns an error if either the name or digest is invalid, or if link
 // creation encounters any issues.
 func (c *DiskCache) Link(name string, d Digest) error {
+	// TODO(bmizerany): Move link handling from cache to registry.
+	//
+	// We originally placed links in the cache due to its storage
+	// knowledge. However, the registry likely offers better context for
+	// naming concerns, and our API design shouldn't be tightly coupled to
+	// our on-disk format.
+	//
+	// Links work effectively when independent from physical location -
+	// they can reference content with matching SHA regardless of storage
+	// location. In an upcoming change, we plan to shift this
+	// responsibility to the registry where it better aligns with the
+	// system's conceptual model.
 	manifest, err := c.manifestPath(name)
 	if err != nil {
 		return err
@@ -304,21 +316,19 @@ func (c *DiskCache) Link(name string, d Digest) error {
 	return c.copyNamedFile(manifest, f, d, info.Size())
 }
 
-// Unlink removes the any link for name. If the link does not exist, nothing
-// happens, and no error is returned.
-//
-// It returns an error if the name is invalid or if the link removal encounters
-// any issues.
-func (c *DiskCache) Unlink(name string) error {
+// Unlink unlinks the manifest by name from the cache. If the name is not
+// found. If a manifest is removed ok will be true, otherwise false. If an
+// error occurs, it returns ok false, and the error.
+func (c *DiskCache) Unlink(name string) (ok bool, _ error) {
 	manifest, err := c.manifestPath(name)
 	if err != nil {
-		return err
+		return false, err
 	}
 	err = os.Remove(manifest)
 	if errors.Is(err, fs.ErrNotExist) {
-		return nil
+		return false, nil
 	}
-	return err
+	return true, err
 }
 
 // GetFile returns the absolute path to the file, in the cache, for the given
diff --git a/server/internal/cache/blob/cache_test.go b/server/internal/cache/blob/cache_test.go
index 704542ea..af29a312 100644
--- a/server/internal/cache/blob/cache_test.go
+++ b/server/internal/cache/blob/cache_test.go
@@ -13,7 +13,7 @@ import (
 	"testing"
 	"time"
 
-	"github.com/ollama/ollama/server/internal/internal/testutil"
+	"github.com/ollama/ollama/server/internal/testutil"
 )
 
 func init() {
@@ -479,8 +479,11 @@ func testManifestNameReuse(t *testing.T) {
 	}
 
 	// relink with different case
-	err = c.Unlink("h/n/m:t")
+	unlinked, err := c.Unlink("h/n/m:t")
 	check(err)
+	if !unlinked {
+		t.Fatal("expected unlinked")
+	}
 	err = c.Link("h/n/m:T", d1)
 	check(err)
 
diff --git a/server/internal/cache/blob/casecheck_test.go b/server/internal/cache/blob/casecheck_test.go
index f0842ef9..5895d2cb 100644
--- a/server/internal/cache/blob/casecheck_test.go
+++ b/server/internal/cache/blob/casecheck_test.go
@@ -86,7 +86,7 @@ func useCaseInsensitiveTempDir(t *testing.T) bool {
 		// link to docs on that topic.
 		lines := strings.Split(volumeHint, "\n")
 		for _, line := range lines {
-			t.Log(line)
+			t.Skip(line)
 		}
 	}
 	return false
diff --git a/server/internal/client/ollama/registry.go b/server/internal/client/ollama/registry.go
index 13612272..d4d58ed6 100644
--- a/server/internal/client/ollama/registry.go
+++ b/server/internal/client/ollama/registry.go
@@ -19,6 +19,7 @@ import (
 	"fmt"
 	"io"
 	"io/fs"
+	"log/slog"
 	"net/http"
 	"os"
 	"path/filepath"
@@ -86,9 +87,23 @@ func DefaultCache() (*blob.DiskCache, error) {
 	return blob.Open(dir)
 }
 
-// Error is the standard error returned by Ollama APIs.
+// Error is the standard error returned by Ollama APIs. It can represent a
+// single or multiple error response.
+//
+// Single error responses have the following format:
+//
+//	{"code": "optional_code","error":"error message"}
+//
+// Multiple error responses have the following format:
+//
+//	{"errors": [{"code": "optional_code","message":"error message"}]}
+//
+// Note, that the error field is used in single error responses, while the
+// message field is used in multiple error responses.
+//
+// In both cases, the code field is optional and may be empty.
 type Error struct {
-	Status  int    `json:"-"`
+	Status  int    `json:"-"` // TODO(bmizerany): remove this
 	Code    string `json:"code"`
 	Message string `json:"message"`
 }
@@ -97,13 +112,34 @@ func (e *Error) Error() string {
 	return fmt.Sprintf("registry responded with status %d: %s %s", e.Status, e.Code, e.Message)
 }
 
+func (e *Error) LogValue() slog.Value {
+	return slog.GroupValue(
+		slog.Int("status", e.Status),
+		slog.String("code", e.Code),
+		slog.String("message", e.Message),
+	)
+}
+
 // UnmarshalJSON implements json.Unmarshaler.
 func (e *Error) UnmarshalJSON(b []byte) error {
 	type E Error
-	var v struct{ Errors []E }
+	var v struct {
+		// Single error
+		Code  string
+		Error string
+
+		// Multiple errors
+		Errors []E
+	}
 	if err := json.Unmarshal(b, &v); err != nil {
 		return err
 	}
+	if v.Error != "" {
+		// Single error case
+		e.Code = v.Code
+		e.Message = v.Error
+		return nil
+	}
 	if len(v.Errors) == 0 {
 		return fmt.Errorf("no messages in error response: %s", string(b))
 	}
@@ -111,9 +147,8 @@ func (e *Error) UnmarshalJSON(b []byte) error {
 	return nil
 }
 
-// TODO(bmizerany): make configurable on [Registry]
 var defaultName = func() names.Name {
-	n := names.Parse("ollama.com/library/_:latest")
+	n := names.Parse("registry.ollama.ai/library/_:latest")
 	if !n.IsFullyQualified() {
 		panic("default name is not fully qualified")
 	}
@@ -160,21 +195,26 @@ type Registry struct {
 	//
 	// It is only used when a layer is larger than [MaxChunkingThreshold].
 	MaxChunkSize int64
+
+	// NameMask, if set, is the name used to convert non-fully qualified
+	// names to fully qualified names. If empty, the default mask
+	// ("registry.ollama.ai/library/_:latest") is used.
+	NameMask string
 }
 
-// RegistryFromEnv returns a new Registry configured from the environment. The
+// DefaultRegistry returns a new Registry configured from the environment. The
 // key is read from $HOME/.ollama/id_ed25519, MaxStreams is set to the
 // value of OLLAMA_REGISTRY_MAXSTREAMS, and ChunkingDirectory is set to the
 // system's temporary directory.
 //
 // It returns an error if any configuration in the environment is invalid.
-func RegistryFromEnv() (*Registry, error) {
+func DefaultRegistry() (*Registry, error) {
 	home, err := os.UserHomeDir()
 	if err != nil {
 		return nil, err
 	}
 	keyPEM, err := os.ReadFile(filepath.Join(home, ".ollama/id_ed25519"))
-	if err != nil {
+	if err != nil && errors.Is(err, fs.ErrNotExist) {
 		return nil, err
 	}
 
@@ -208,9 +248,19 @@ type PushParams struct {
 // any, is invalid.
 //
 // The scheme is returned as provided by [names.ParseExtended].
-func parseName(s string) (scheme string, n names.Name, d blob.Digest, err error) {
+func parseName(s, mask string) (scheme string, n names.Name, d blob.Digest, err error) {
+	maskName := defaultName
+	if mask != "" {
+		maskName = names.Parse(mask)
+		if !maskName.IsFullyQualified() {
+			return "", names.Name{}, blob.Digest{}, fmt.Errorf("invalid name mask: %s", mask)
+		}
+	}
 	scheme, n, ds := names.ParseExtended(s)
-	n = names.Merge(n, defaultName)
+	if !n.IsValid() {
+		return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s)
+	}
+	n = names.Merge(n, maskName)
 	if ds != "" {
 		// Digest is present. Validate it.
 		d, err = blob.ParseDigest(ds)
@@ -223,7 +273,7 @@ func parseName(s string) (scheme string, n names.Name, d blob.Digest, err error)
 	// say that digests take precedence over names, and so should there
 	// errors when being parsed.
 	if !n.IsFullyQualified() {
-		return "", names.Name{}, blob.Digest{}, ErrNameInvalid
+		return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s)
 	}
 
 	scheme = cmp.Or(scheme, "https")
@@ -255,7 +305,7 @@ func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *
 		p = &PushParams{}
 	}
 
-	m, err := ResolveLocal(c, cmp.Or(p.From, name))
+	m, err := r.ResolveLocal(c, cmp.Or(p.From, name))
 	if err != nil {
 		return err
 	}
@@ -278,7 +328,7 @@ func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *
 
 	t := traceFromContext(ctx)
 
-	scheme, n, _, err := parseName(name)
+	scheme, n, _, err := parseName(name, r.NameMask)
 	if err != nil {
 		// This should never happen since ResolveLocal should have
 		// already validated the name.
@@ -372,7 +422,7 @@ func canRetry(err error) bool {
 // typically slower than splitting the model up across layers, and is mostly
 // utilized for layers of type equal to "application/vnd.ollama.image".
 func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error {
-	scheme, n, _, err := parseName(name)
+	scheme, n, _, err := parseName(name, r.NameMask)
 	if err != nil {
 		return err
 	}
@@ -520,6 +570,16 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err
 	return c.Link(m.Name, md)
 }
 
+// Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
+// before attempting to unlink the model.
+func (r *Registry) Unlink(c *blob.DiskCache, name string) (ok bool, _ error) {
+	_, n, _, err := parseName(name, r.NameMask)
+	if err != nil {
+		return false, err
+	}
+	return c.Unlink(n.String())
+}
+
 // Manifest represents a [ollama.com/manifest].
 type Manifest struct {
 	Name   string   `json:"-"` // the canonical name of the model
@@ -590,8 +650,8 @@ type Layer struct {
 
 // ResolveLocal resolves a name to a Manifest in the local cache. The name is
 // parsed using [names.ParseExtended] but the scheme is ignored.
-func ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
-	_, n, d, err := parseName(name)
+func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
+	_, n, d, err := parseName(name, r.NameMask)
 	if err != nil {
 		return nil, err
 	}
@@ -617,7 +677,7 @@ func ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
 
 // Resolve resolves a name to a Manifest in the remote registry.
 func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) {
-	scheme, n, d, err := parseName(name)
+	scheme, n, d, err := parseName(name, r.NameMask)
 	if err != nil {
 		return nil, err
 	}
diff --git a/server/internal/client/ollama/registry_test.go b/server/internal/client/ollama/registry_test.go
index d8f2a407..af898c26 100644
--- a/server/internal/client/ollama/registry_test.go
+++ b/server/internal/client/ollama/registry_test.go
@@ -21,7 +21,7 @@ import (
 
 	"github.com/ollama/ollama/server/internal/cache/blob"
 	"github.com/ollama/ollama/server/internal/chunks"
-	"github.com/ollama/ollama/server/internal/internal/testutil"
+	"github.com/ollama/ollama/server/internal/testutil"
 )
 
 func TestManifestMarshalJSON(t *testing.T) {
@@ -37,20 +37,6 @@ func TestManifestMarshalJSON(t *testing.T) {
 	}
 }
 
-func link(c *blob.DiskCache, name string, manifest string) {
-	_, n, _, err := parseName(name)
-	if err != nil {
-		panic(err)
-	}
-	d, err := c.Import(bytes.NewReader([]byte(manifest)), int64(len(manifest)))
-	if err != nil {
-		panic(err)
-	}
-	if err := c.Link(n.String(), d); err != nil {
-		panic(err)
-	}
-}
-
 var errRoundTrip = errors.New("forced roundtrip error")
 
 type recordRoundTripper http.HandlerFunc
@@ -98,29 +84,44 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
 		}
 	}
 
+	rc := &Registry{
+		HTTPClient: &http.Client{
+			Transport: recordRoundTripper(h),
+		},
+	}
+
+	link := func(name string, manifest string) {
+		_, n, _, err := parseName(name, rc.NameMask)
+		if err != nil {
+			panic(err)
+		}
+		d, err := c.Import(bytes.NewReader([]byte(manifest)), int64(len(manifest)))
+		if err != nil {
+			panic(err)
+		}
+		if err := c.Link(n.String(), d); err != nil {
+			panic(err)
+		}
+	}
+
 	commit := func(name string, layers ...*Layer) {
 		t.Helper()
 		data, err := json.Marshal(&Manifest{Layers: layers})
 		if err != nil {
 			t.Fatal(err)
 		}
-		link(c, name, string(data))
+		link(name, string(data))
 	}
 
-	link(c, "empty", "")
+	link("empty", "")
 	commit("zero")
 	commit("single", mklayer("exists"))
 	commit("multiple", mklayer("exists"), mklayer("present"))
 	commit("notfound", &Layer{Digest: blob.DigestFromBytes("notfound"), Size: int64(len("notfound"))})
 	commit("null", nil)
 	commit("sizemismatch", mklayer("exists"), &Layer{Digest: blob.DigestFromBytes("present"), Size: 499})
-	link(c, "invalid", "!!!!!")
+	link("invalid", "!!!!!")
 
-	rc := &Registry{
-		HTTPClient: &http.Client{
-			Transport: recordRoundTripper(h),
-		},
-	}
 	return rc, c
 }
 
@@ -385,7 +386,7 @@ func TestRegistryPullNotCached(t *testing.T) {
 	})
 
 	// Confirm that the layer does not exist locally
-	_, err := ResolveLocal(c, "model")
+	_, err := rc.ResolveLocal(c, "model")
 	checkNotExist(t, err)
 
 	_, err = c.Get(d)
@@ -396,7 +397,7 @@ func TestRegistryPullNotCached(t *testing.T) {
 
 	mw, err := rc.Resolve(t.Context(), "model")
 	check(err)
-	mg, err := ResolveLocal(c, "model")
+	mg, err := rc.ResolveLocal(c, "model")
 	check(err)
 	if !reflect.DeepEqual(mw, mg) {
 		t.Errorf("mw = %v; mg = %v", mw, mg)
@@ -654,3 +655,72 @@ func TestCanRetry(t *testing.T) {
 		}
 	}
 }
+
+func TestErrorUnmarshal(t *testing.T) {
+	cases := []struct {
+		name    string
+		data    string
+		want    *Error
+		wantErr bool
+	}{
+		{
+			name:    "errors empty",
+			data:    `{"errors":[]}`,
+			wantErr: true,
+		},
+		{
+			name:    "errors empty",
+			data:    `{"errors":[]}`,
+			wantErr: true,
+		},
+		{
+			name: "errors single",
+			data: `{"errors":[{"code":"blob_unknown"}]}`,
+			want: &Error{Code: "blob_unknown", Message: ""},
+		},
+		{
+			name: "errors multiple",
+			data: `{"errors":[{"code":"blob_unknown"},{"code":"blob_error"}]}`,
+			want: &Error{Code: "blob_unknown", Message: ""},
+		},
+		{
+			name:    "error empty",
+			data:    `{"error":""}`,
+			wantErr: true,
+		},
+		{
+			name:    "error very empty",
+			data:    `{}`,
+			wantErr: true,
+		},
+		{
+			name: "error message",
+			data: `{"error":"message", "code":"code"}`,
+			want: &Error{Code: "code", Message: "message"},
+		},
+		{
+			name:    "invalid value",
+			data:    `{"error": 1}`,
+			wantErr: true,
+		},
+	}
+	for _, tt := range cases {
+		t.Run(tt.name, func(t *testing.T) {
+			var got Error
+			err := json.Unmarshal([]byte(tt.data), &got)
+			if err != nil {
+				if tt.wantErr {
+					return
+				}
+				t.Errorf("Unmarshal() error = %v", err)
+				// fallthrough and check got
+			}
+			if tt.want == nil {
+				tt.want = &Error{}
+			}
+			if !reflect.DeepEqual(got, *tt.want) {
+				t.Errorf("got = %v; want %v", got, *tt.want)
+			}
+		})
+	}
+}
diff --git a/server/internal/cmd/opp/opp.go b/server/internal/cmd/opp/opp.go
index 12199cf3..cc10a72f 100644
--- a/server/internal/cmd/opp/opp.go
+++ b/server/internal/cmd/opp/opp.go
@@ -68,7 +68,7 @@ func main() {
 		log.Fatal(err)
 	}
 
-	rc, err := ollama.RegistryFromEnv()
+	rc, err := ollama.DefaultRegistry()
 	if err != nil {
 		log.Fatal(err)
 	}
@@ -177,7 +177,7 @@ func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
 	}
 
 	from := cmp.Or(*flagFrom, model)
-	m, err := ollama.ResolveLocal(c, from)
+	m, err := rc.ResolveLocal(c, from)
 	if err != nil {
 		return err
 	}
diff --git a/server/internal/registry/server.go b/server/internal/registry/server.go
new file mode 100644
index 00000000..8d6dc1aa
--- /dev/null
+++ b/server/internal/registry/server.go
@@ -0,0 +1,215 @@
+// Package registry provides an http.Handler for handling local Ollama API
+// requests for performing tasks related to the ollama.com model registry and
+// the local disk cache.
+package registry
+
+import (
+	"cmp"
+	"encoding/json"
+	"errors"
+	"io"
+	"log/slog"
+	"net/http"
+
+	"github.com/ollama/ollama/server/internal/cache/blob"
+	"github.com/ollama/ollama/server/internal/client/ollama"
+)
+
+// Local is an http.Handler for handling local Ollama API requests for
+// performing tasks related to the ollama.com model registry combined with the
+// local disk cache.
+//
+// It is not concern of Local, or this package, to handle model creation, which
+// proceeds any registry operations for models it produces.
+//
+// NOTE: The package built for dealing with model creation should use
+// [DefaultCache] to access the blob store and not attempt to read or write
+// directly to the blob disk cache.
+type Local struct {
+	Client *ollama.Registry // required
+	Cache  *blob.DiskCache  // required
+	Logger *slog.Logger     // required
+
+	// Fallback, if set, is used to handle requests that are not handled by
+	// this handler.
+	Fallback http.Handler
+}
+
+// serverError is like ollama.Error, but with a Status field for the HTTP
+// response code. We want to avoid adding that field to ollama.Error because it
+// would always be 0 to clients (we don't want to leak the status code in
+// errors), and so it would be confusing to have a field that is always 0.
+type serverError struct {
+	Status int `json:"-"`
+
+	// TODO(bmizerany): Decide if we want to keep this and maybe
+	// bring back later.
+	Code string `json:"code"`
+
+	Message string `json:"error"`
+}
+
+func (e serverError) Error() string {
+	return e.Message
+}
+
+// Common API errors
+var (
+	errMethodNotAllowed = &serverError{405, "method_not_allowed", "method not allowed"}
+	errNotFound         = &serverError{404, "not_found", "not found"}
+	errInternalError    = &serverError{500, "internal_error", "internal server error"}
+)
+
+type statusCodeRecorder struct {
+	_status int // use status() to get the status code
+	http.ResponseWriter
+}
+
+func (r *statusCodeRecorder) WriteHeader(status int) {
+	if r._status == 0 {
+		r._status = status
+	}
+	r.ResponseWriter.WriteHeader(status)
+}
+
+func (r *statusCodeRecorder) status() int {
+	return cmp.Or(r._status, 200)
+}
+
+func (s *Local) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	rec := &statusCodeRecorder{ResponseWriter: w}
+	s.serveHTTP(rec, r)
+}
+
+func (s *Local) serveHTTP(rec *statusCodeRecorder, r *http.Request) {
+	var errattr slog.Attr
+	proxied, err := func() (bool, error) {
+		switch r.URL.Path {
+		case "/api/delete":
+			return false, s.handleDelete(rec, r)
+		default:
+			if s.Fallback != nil {
+				s.Fallback.ServeHTTP(rec, r)
+				return true, nil
+			}
+			return false, errNotFound
+		}
+	}()
+	if err != nil {
+		// We always log the error, so fill in the error log attribute
+		errattr = slog.String("error", err.Error())
+
+		var e *serverError
+		switch {
+		case errors.As(err, &e):
+		case errors.Is(err, ollama.ErrNameInvalid):
+			e = &serverError{400, "bad_request", err.Error()}
+		default:
+			e = errInternalError
+		}
+
+		data, err := json.Marshal(e)
+		if err != nil {
+			// unreachable
+			panic(err)
+		}
+		rec.Header().Set("Content-Type", "application/json")
+		rec.WriteHeader(e.Status)
+		rec.Write(data)
+
+		// fallthrough to log
+	}
+
+	if !proxied {
+		// we're only responsible for logging if we handled the request
+		var level slog.Level
+		if rec.status() >= 500 {
+			level = slog.LevelError
+		} else if rec.status() >= 400 {
+			level = slog.LevelWarn
+		}
+
+		s.Logger.LogAttrs(r.Context(), level, "http",
+			errattr, // report first in line to make it easy to find
+
+			// TODO(bmizerany): Write a test to ensure that we are logging
+			// all of this correctly. That also goes for the level+error
+			// logic above.
+			slog.Int("status", rec.status()),
+			slog.String("method", r.Method),
+			slog.String("path", r.URL.Path),
+			slog.Int64("content-length", r.ContentLength),
+			slog.String("remote", r.RemoteAddr),
+			slog.String("proto", r.Proto),
+			slog.String("query", r.URL.RawQuery),
+		)
+	}
+}
+
+type params struct {
+	DeprecatedName string `json:"name"`  // Use [params.model]
+	Model          string `json:"model"` // Use [params.model]
+
+	// AllowNonTLS is a flag that indicates a client using HTTP
+	// is doing so, deliberately.
+	//
+	// Deprecated: This field is ignored and only present for this
+	// deprecation message. It should be removed in a future release.
+	//
+	// Users can just use http or https+insecure to show intent to
+	// communicate they want to do insecure things, without awkward and
+	// confusing flags such as this.
+	AllowNonTLS bool `json:"insecure"`
+
+	// ProgressStream is a flag that indicates the client is expecting a stream of
+	// progress updates.
+	ProgressStream bool `json:"stream"`
+}
+
+// model returns the model name for both old and new API requests.
+func (p params) model() string {
+	return cmp.Or(p.Model, p.DeprecatedName)
+}
+
+func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
+	if r.Method != "DELETE" {
+		return errMethodNotAllowed
+	}
+	p, err := decodeUserJSON[*params](r.Body)
+	if err != nil {
+		return err
+	}
+	ok, err := s.Client.Unlink(s.Cache, p.model())
+	if err != nil {
+		return err
+	}
+	if !ok {
+		return &serverError{404, "manifest_not_found", "manifest not found"}
+	}
+	return nil
+}
+
+func decodeUserJSON[T any](r io.Reader) (T, error) {
+	var v T
+	err := json.NewDecoder(r).Decode(&v)
+	if err == nil {
+		return v, nil
+	}
+	var zero T
+
+	// Not sure why, but I can't seem to be able to use:
+	//
+	//   errors.As(err, &json.UnmarshalTypeError{})
+	//
+	// This is working fine in stdlib, so I'm not sure what rules changed
+	// and why this no longer works here. So, we do it the verbose way.
+	var a *json.UnmarshalTypeError
+	var b *json.SyntaxError
+	if errors.As(err, &a) || errors.As(err, &b) {
+		err = &serverError{Status: 400, Message: err.Error(), Code: "bad_request"}
+	}
+	if errors.Is(err, io.EOF) {
+		err = &serverError{Status: 400, Message: "empty request body", Code: "bad_request"}
+	}
+	return zero, err
+}
diff --git a/server/internal/registry/server_test.go b/server/internal/registry/server_test.go
new file mode 100644
index 00000000..22267ba7
--- /dev/null
+++ b/server/internal/registry/server_test.go
@@ -0,0 +1,168 @@
+package registry
+
+import (
+	"encoding/json"
+	"net/http"
+	"net/http/httptest"
+	"os"
+	"regexp"
+	"strings"
+	"testing"
+
+	"github.com/ollama/ollama/server/internal/cache/blob"
+	"github.com/ollama/ollama/server/internal/client/ollama"
+	"github.com/ollama/ollama/server/internal/testutil"
+)
+
+type panicTransport struct{}
+
+func (t *panicTransport) RoundTrip(r *http.Request) (*http.Response, error) {
+	panic("unexpected RoundTrip call")
+}
+
+var panicOnRoundTrip = &http.Client{Transport: &panicTransport{}}
+
+// bytesResetter is an interface for types that can be reset and return a byte
+// slice, only. This is to prevent inadvertent use of bytes.Buffer.Read/Write
+// etc for the purpose of checking logs.
+type bytesResetter interface {
+	Bytes() []byte
+	Reset()
+}
+
+func newTestServer(t *testing.T) *Local {
+	t.Helper()
+	dir := t.TempDir()
+	err := os.CopyFS(dir, os.DirFS("testdata/models"))
+	if err != nil {
+		t.Fatal(err)
+	}
+	c, err := blob.Open(dir)
+	if err != nil {
+		t.Fatal(err)
+	}
+	rc := &ollama.Registry{
+		HTTPClient: panicOnRoundTrip,
+	}
+	l := &Local{
+		Cache:  c,
+		Client: rc,
+		Logger: testutil.Slogger(t),
+	}
+	return l
+}
+
+func (s *Local) send(t *testing.T, method, path, body string) *httptest.ResponseRecorder {
+	t.Helper()
+	req := httptest.NewRequestWithContext(t.Context(), method, path, strings.NewReader(body))
+	return s.sendRequest(t, req)
+}
+
+func (s *Local) sendRequest(t *testing.T, req *http.Request) *httptest.ResponseRecorder {
+	t.Helper()
+	w := httptest.NewRecorder()
+	s.ServeHTTP(w, req)
+	return w
+}
+
+type invalidReader struct{}
+
+func (r *invalidReader) Read(p []byte) (int, error) {
+	return 0, os.ErrInvalid
+}
+
+// captureLogs is a helper to capture logs from the server. It returns a
+// shallow copy of the server with a new logger and a bytesResetter for the
+// logs.
+func captureLogs(t *testing.T, s *Local) (*Local, bytesResetter) {
+	t.Helper()
+	log, logs := testutil.SlogBuffer()
+	l := *s // shallow copy
+	l.Logger = log
+	return &l, logs
+}
+
+func TestServerDelete(t *testing.T) {
+	check := testutil.Checker(t)
+
+	s := newTestServer(t)
+
+	_, err := s.Client.ResolveLocal(s.Cache, "smol")
+	check(err)
+
+	got := s.send(t, "DELETE", "/api/delete", `{"model": "smol"}`)
+	if got.Code != 200 {
+		t.Fatalf("Code = %d; want 200", got.Code)
+	}
+
+	_, err = s.Client.ResolveLocal(s.Cache, "smol")
+	if err == nil {
+		t.Fatal("expected smol to have been deleted")
+	}
+
+	got = s.send(t, "DELETE", "/api/delete", `!`)
+	checkErrorResponse(t, got, 400, "bad_request", "invalid character '!' looking for beginning of value")
+
+	got = s.send(t, "GET", "/api/delete", `{"model": "smol"}`)
+	checkErrorResponse(t, got, 405, "method_not_allowed", "method not allowed")
+
+	got = s.send(t, "DELETE", "/api/delete", ``)
+	checkErrorResponse(t, got, 400, "bad_request", "empty request body")
+
+	got = s.send(t, "DELETE", "/api/delete", `{"model": "!"}`)
+	checkErrorResponse(t, got, 404, "manifest_not_found", "not found")
+
+	got = s.send(t, "DELETE", "/api/delete", `{"model": "://"}`)
+	checkErrorResponse(t, got, 400, "bad_request", "invalid name")
+
+	got = s.send(t, "DELETE", "/unknown_path", `{}`) // valid body
+	checkErrorResponse(t, got, 404, "not_found", "not found")
+
+	s, logs := captureLogs(t, s)
+	req := httptest.NewRequestWithContext(t.Context(), "DELETE", "/api/delete", &invalidReader{})
+	got = s.sendRequest(t, req)
+	checkErrorResponse(t, got, 500, "internal_error", "internal server error")
+	ok, err := regexp.Match(`ERROR.*error="invalid argument"`, logs.Bytes())
+	check(err)
+	if !ok {
+		t.Logf("logs:\n%s", logs)
+		t.Fatalf("expected log to contain ERROR with invalid argument")
+	}
+}
+
+func TestServerUnknownPath(t *testing.T) {
+	s := newTestServer(t)
+	got := s.send(t, "DELETE", "/api/unknown", `{}`)
+	checkErrorResponse(t, got, 404, "not_found", "not found")
+}
+
+func checkErrorResponse(t *testing.T, got *httptest.ResponseRecorder, status int, code, msg string) {
+	t.Helper()
+
+	var printedBody bool
+	errorf := func(format string, args ...any) {
+		t.Helper()
+		if !printedBody {
+			t.Logf("BODY:\n%s", got.Body.String())
+			printedBody = true
+		}
+		t.Errorf(format, args...)
+	}
+
+	if got.Code != status {
+		errorf("Code = %d; want %d", got.Code, status)
+	}
+
+	// unmarshal the error as *ollama.Error (proving *serverError is an *ollama.Error)
+	var e *ollama.Error
+	if err := json.Unmarshal(got.Body.Bytes(), &e); err != nil {
+		errorf("unmarshal error: %v", err)
+		t.FailNow()
+	}
+	if e.Code != code {
+		errorf("Code = %q; want %q", e.Code, code)
+	}
+	if !strings.Contains(e.Message, msg) {
+		errorf("Message = %q; want to contain %q", e.Message, msg)
+	}
+}
diff --git a/server/internal/registry/testdata/models/blobs/sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99 b/server/internal/registry/testdata/models/blobs/sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99
new file mode 100644
index 0000000000000000000000000000000000000000..def4dffc777741d0c39e37dc05816e45e7c329ce
GIT binary patch
literal 24
OcmZ>F4|QW^zytsqX#p(&

literal 0
HcmV?d00001

diff --git a/server/internal/registry/testdata/models/blobs/sha256-ecfb1acfca9c76444d622fcdc3840217bd502124a9d3687d438c19b3cb9c3cb1 b/server/internal/registry/testdata/models/blobs/sha256-ecfb1acfca9c76444d622fcdc3840217bd502124a9d3687d438c19b3cb9c3cb1
new file mode 100644
index 00000000..62114d06
--- /dev/null
+++ b/server/internal/registry/testdata/models/blobs/sha256-ecfb1acfca9c76444d622fcdc3840217bd502124a9d3687d438c19b3cb9c3cb1
@@ -0,0 +1 @@
+{"schemaVersion":2,"mediaType":"application/vnd.docker.distribution.manifest.v2+json","config":{"mediaType":"application/vnd.docker.container.image.v1+json","digest":"sha256:ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116","size":267},"layers":[{"mediaType":"application/vnd.ollama.image.model","digest":"sha256:a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99","size":24}]}
\ No newline at end of file
diff --git a/server/internal/registry/testdata/models/manifests/registry.ollama.ai/library/smol/latest b/server/internal/registry/testdata/models/manifests/registry.ollama.ai/library/smol/latest
new file mode 100644
index 00000000..62114d06
--- /dev/null
+++ b/server/internal/registry/testdata/models/manifests/registry.ollama.ai/library/smol/latest
@@ -0,0 +1 @@
+{"schemaVersion":2,"mediaType":"application/vnd.docker.distribution.manifest.v2+json","config":{"mediaType":"application/vnd.docker.container.image.v1+json","digest":"sha256:ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116","size":267},"layers":[{"mediaType":"application/vnd.ollama.image.model","digest":"sha256:a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99","size":24}]}
\ No newline at end of file
diff --git a/server/internal/internal/testutil/testutil.go b/server/internal/testutil/testutil.go
similarity index 72%
rename from server/internal/internal/testutil/testutil.go
rename to server/internal/testutil/testutil.go
index 354c2608..f01df942 100644
--- a/server/internal/internal/testutil/testutil.go
+++ b/server/internal/testutil/testutil.go
@@ -1,12 +1,40 @@
 package testutil
 
 import (
+	"bytes"
+	"io"
+	"log/slog"
 	"os"
 	"path/filepath"
 	"testing"
 	"time"
 )
 
+// LogWriter returns an [io.Writer] that logs each Write using t.Log.
+func LogWriter(t *testing.T) io.Writer {
+	return testWriter{t}
+}
+
+type testWriter struct{ t *testing.T }
+
+func (w testWriter) Write(b []byte) (int, error) {
+	w.t.Logf("%s", b)
+	return len(b), nil
+}
+
+// Slogger returns a [*slog.Logger] that writes each message
+// using t.Log.
+func Slogger(t *testing.T) *slog.Logger {
+	return slog.New(slog.NewTextHandler(LogWriter(t), nil))
+}
+
+// SlogBuffer returns a [*slog.Logger] that writes each message to out.
+func SlogBuffer() (lg *slog.Logger, out *bytes.Buffer) {
+	var buf bytes.Buffer
+	lg = slog.New(slog.NewTextHandler(&buf, nil))
+	return lg, &buf
+}
+
 // Check calls t.Fatal(err) if err is not nil.
 func Check(t *testing.T, err error) {
 	if err != nil {
diff --git a/server/routes.go b/server/routes.go
index de72f847..ff42000f 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -34,6 +34,9 @@ import (
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/model/models/mllama"
 	"github.com/ollama/ollama/openai"
+	"github.com/ollama/ollama/server/internal/cache/blob"
+	"github.com/ollama/ollama/server/internal/client/ollama"
+	"github.com/ollama/ollama/server/internal/registry"
 	"github.com/ollama/ollama/template"
 	"github.com/ollama/ollama/types/errtypes"
 	"github.com/ollama/ollama/types/model"
@@ -1126,7 +1129,7 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
 	}
 }
 
-func (s *Server) GenerateRoutes() http.Handler {
+func (s *Server) GenerateRoutes(c *blob.DiskCache, rc *ollama.Registry) (http.Handler, error) {
 	corsConfig := cors.DefaultConfig()
 	corsConfig.AllowWildcard = true
 	corsConfig.AllowBrowserExtensions = true
@@ -1165,10 +1168,9 @@ func (s *Server) GenerateRoutes() http.Handler {
 	r.HEAD("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
 	r.GET("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
 
-	// Local model cache management
+	// Local model cache management (new implementation is at end of function)
 	r.POST("/api/pull", s.PullHandler)
 	r.POST("/api/push", s.PushHandler)
-	r.DELETE("/api/delete", s.DeleteHandler)
 	r.HEAD("/api/tags", s.ListHandler)
 	r.GET("/api/tags", s.ListHandler)
 	r.POST("/api/show", s.ShowHandler)
@@ -1193,7 +1195,15 @@ func (s *Server) GenerateRoutes() http.Handler {
 	r.GET("/v1/models", openai.ListMiddleware(), s.ListHandler)
 	r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowHandler)
 
-	return r
+	// wrap old with new
+	rs := ®istry.Local{
+		Cache:    c,
+		Client:   rc,
+		Logger:   slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default()
+		Fallback: r,
+	}
+
+	return rs, nil
 }
 
 func Serve(ln net.Listener) error {
@@ -1246,12 +1256,27 @@ func Serve(ln net.Listener) error {
 		}
 	}
 
+	s := &Server{addr: ln.Addr()}
+
+	c, err := ollama.DefaultCache()
+	if err != nil {
+		return err
+	}
+	rc, err := ollama.DefaultRegistry()
+	if err != nil {
+		return err
+	}
+
+	h, err := s.GenerateRoutes(c, rc)
+	if err != nil {
+		return err
+	}
+	http.Handle("/", h)
+
 	ctx, done := context.WithCancel(context.Background())
 	schedCtx, schedDone := context.WithCancel(ctx)
 	sched := InitScheduler(schedCtx)
-	s := &Server{addr: ln.Addr(), sched: sched}
-
-	http.Handle("/", s.GenerateRoutes())
+	s.sched = sched
 
 	slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
 	srvr := &http.Server{
diff --git a/server/routes_test.go b/server/routes_test.go
index c15fc0a5..0dd782f4 100644
--- a/server/routes_test.go
+++ b/server/routes_test.go
@@ -23,6 +23,8 @@ import (
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/fs/ggml"
 	"github.com/ollama/ollama/openai"
+	"github.com/ollama/ollama/server/internal/cache/blob"
+	"github.com/ollama/ollama/server/internal/client/ollama"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
 )
@@ -91,7 +93,15 @@ func equalStringSlices(a, b []string) bool {
 	return true
 }
 
-func Test_Routes(t *testing.T) {
+type panicTransport struct{}
+
+func (t *panicTransport) RoundTrip(r *http.Request) (*http.Response, error) {
+	panic("unexpected RoundTrip call")
+}
+
+var panicOnRoundTrip = &http.Client{Transport: &panicTransport{}}
+
+func TestRoutes(t *testing.T) {
 	type testCase struct {
 		Name     string
 		Method   string
@@ -241,10 +251,10 @@ func Test_Routes(t *testing.T) {
 			Method: http.MethodDelete,
 			Path:   "/api/delete",
 			Setup: func(t *testing.T, req *http.Request) {
-				createTestModel(t, "model-to-delete")
+				createTestModel(t, "model_to_delete")
 
 				deleteReq := api.DeleteRequest{
-					Name: "model-to-delete",
+					Name: "model_to_delete",
 				}
 				jsonData, err := json.Marshal(deleteReq)
 				if err != nil {
@@ -271,7 +281,7 @@ func Test_Routes(t *testing.T) {
 			Path:   "/api/delete",
 			Setup: func(t *testing.T, req *http.Request) {
 				deleteReq := api.DeleteRequest{
-					Name: "non-existent-model",
+					Name: "non_existent_model",
 				}
 				jsonData, err := json.Marshal(deleteReq)
 				if err != nil {
@@ -477,10 +487,34 @@ func Test_Routes(t *testing.T) {
 		},
 	}
 
-	t.Setenv("OLLAMA_MODELS", t.TempDir())
+	modelsDir := t.TempDir()
+	t.Setenv("OLLAMA_MODELS", modelsDir)
+
+	c, err := blob.Open(modelsDir)
+	if err != nil {
+		t.Fatalf("failed to open models dir: %v", err)
+	}
+
+	rc := &ollama.Registry{
+		// This is a temporary measure to allow us to move forward,
+		// surfacing any code contacting ollama.com we do not intended
+		// to.
+		//
+		// Currently, this only handles DELETE /api/delete, which
+		// should not make any contact with the ollama.com registry, so
+		// be clear about that.
+		//
+		// Tests that do need to contact the registry here, will be
+		// consumed into our new server/api code packages and removed
+		// from here.
+		HTTPClient: panicOnRoundTrip,
+	}
 
 	s := &Server{}
-	router := s.GenerateRoutes()
+	router, err := s.GenerateRoutes(c, rc)
+	if err != nil {
+		t.Fatalf("failed to generate routes: %v", err)
+	}
 
 	httpSrv := httptest.NewServer(router)
 	t.Cleanup(httpSrv.Close)

From e185c08ad9c71c82c4bef970a9c1f732e1b68076 Mon Sep 17 00:00:00 2001
From: Jesse Gross 
Date: Thu, 27 Feb 2025 11:35:37 -0800
Subject: [PATCH 12/20] go.mod: Use full version for go 1.24.0

Otherwise on Linux I get:
go: download go1.24 for linux/amd64: toolchain not available
---
 go.mod | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/go.mod b/go.mod
index 5f08aad0..af0cedc8 100644
--- a/go.mod
+++ b/go.mod
@@ -1,6 +1,6 @@
 module github.com/ollama/ollama
 
-go 1.24
+go 1.24.0
 
 require (
 	github.com/containerd/console v1.0.3

From 53d2990d9b60fae08437e98141eca5d9e393deaa Mon Sep 17 00:00:00 2001
From: Michael Yang 
Date: Wed, 26 Feb 2025 11:20:51 -0800
Subject: [PATCH 13/20] model: add bos token if configured

---
 fs/ggml/ggml.go              |  6 +++++-
 ml/backend.go                |  1 +
 model/models/llama/model.go  |  2 ++
 model/models/mllama/model.go |  2 ++
 model/process_text.go        | 23 ++++++++++++++++++++++-
 5 files changed, 32 insertions(+), 2 deletions(-)

diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go
index 57313859..b9f9cc17 100644
--- a/fs/ggml/ggml.go
+++ b/fs/ggml/ggml.go
@@ -100,6 +100,10 @@ func (kv KV) Float(key string, defaultValue ...float32) float32 {
 	return keyValue(kv, key, append(defaultValue, 0)...)
 }
 
+func (kv KV) Bool(key string, defaultValue ...bool) bool {
+	return keyValue(kv, key, append(defaultValue, false)...)
+}
+
 func (kv KV) Strings(key string, defaultValue ...[]string) []string {
 	r := keyValue(kv, key, &array{})
 	s := make([]string, r.size)
@@ -120,7 +124,7 @@ func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
 	return s
 }
 
-func keyValue[T string | uint32 | uint64 | float32 | *array](kv KV, key string, defaultValue ...T) T {
+func keyValue[T string | uint32 | uint64 | float32 | *array | bool](kv KV, key string, defaultValue ...T) T {
 	if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
 		key = kv.Architecture() + "." + key
 	}
diff --git a/ml/backend.go b/ml/backend.go
index 6e3f0516..a742ee5c 100644
--- a/ml/backend.go
+++ b/ml/backend.go
@@ -14,6 +14,7 @@ type Config interface {
 	String(string, ...string) string
 	Uint(string, ...uint32) uint32
 	Float(string, ...float32) float32
+	Bool(string, ...bool) bool
 
 	Strings(string, ...[]string) []string
 	Uints(string, ...[]uint32) []uint32
diff --git a/model/models/llama/model.go b/model/models/llama/model.go
index 4fe02999..6106af86 100644
--- a/model/models/llama/model.go
+++ b/model/models/llama/model.go
@@ -37,7 +37,9 @@ func New(c ml.Config) (model.Model, error) {
 				Types:  c.Uints("tokenizer.ggml.token_type"),
 				Merges: c.Strings("tokenizer.ggml.merges"),
 				BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
+				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
 				EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id")),
+				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
 			},
 		),
 		Layers: make([]Layer, c.Uint("block_count")),
diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go
index f5521ce5..9b35a262 100644
--- a/model/models/mllama/model.go
+++ b/model/models/mllama/model.go
@@ -33,7 +33,9 @@ func New(c ml.Config) (model.Model, error) {
 				Types:  c.Uints("tokenizer.ggml.token_type"),
 				Merges: c.Strings("tokenizer.ggml.merges"),
 				BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
+				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
 				EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id")),
+				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
 			},
 		),
 		ImageProcessor: newImageProcessor(c),
diff --git a/model/process_text.go b/model/process_text.go
index df1e68f4..7083f36f 100644
--- a/model/process_text.go
+++ b/model/process_text.go
@@ -30,7 +30,8 @@ type Vocabulary struct {
 	Scores []uint32
 	Merges []string
 
-	BOS, EOS int32
+	BOS, EOS       int32
+	AddBOS, AddEOS bool
 
 	specialOnce sync.Once
 	special     []string
@@ -281,6 +282,26 @@ func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
 		}
 	}
 
+	if len(ids) > 0 {
+		if bpe.vocab.AddBOS {
+			if ids[0] == bpe.vocab.BOS {
+				slog.Warn("adding bos token to prompt which already has it", "id", bpe.vocab.BOS)
+			}
+
+			slog.Debug("adding bos token to prompt", "id", bpe.vocab.BOS)
+			ids = append([]int32{bpe.vocab.BOS}, ids...)
+		}
+
+		if bpe.vocab.AddEOS {
+			if ids[len(ids)-1] == bpe.vocab.EOS {
+				slog.Warn("adding eos token to prompt which already has it", "id", bpe.vocab.EOS)
+			}
+
+			slog.Debug("adding eos token to prompt", "id", bpe.vocab.EOS)
+			ids = append(ids, bpe.vocab.EOS)
+		}
+	}
+
 	return ids, nil
 }
 

From 41dc280491a7054876e826dac0eff66836d53ae8 Mon Sep 17 00:00:00 2001
From: Blake Mizerany 
Date: Thu, 27 Feb 2025 14:00:37 -0800
Subject: [PATCH 14/20] server/internal/registry: implement CloseNotify and
 Flush (for now) (#9402)

This fixes panics introduced in 2412adf42b8380748ac79476e273f5b337c3b977
when Gin ungracefully assumes that the http.ResponseWriter implements
http.CloseNotifier and http.Flusher, which our new statusCodeRecorder
does not. This is a temporary fix until we can pour the rest of the Gin
out.
---
 server/internal/registry/server.go | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/server/internal/registry/server.go b/server/internal/registry/server.go
index 8d6dc1aa..8eb6daf8 100644
--- a/server/internal/registry/server.go
+++ b/server/internal/registry/server.go
@@ -72,6 +72,26 @@ func (r *statusCodeRecorder) WriteHeader(status int) {
 	r.ResponseWriter.WriteHeader(status)
 }
 
+var (
+	_ http.ResponseWriter = (*statusCodeRecorder)(nil)
+	_ http.CloseNotifier  = (*statusCodeRecorder)(nil)
+	_ http.Flusher        = (*statusCodeRecorder)(nil)
+)
+
+// CloseNotify implements the http.CloseNotifier interface, for Gin. Remove with Gin.
+//
+// It panics if the underlying ResponseWriter is not a CloseNotifier.
+func (r *statusCodeRecorder) CloseNotify() <-chan bool {
+	return r.ResponseWriter.(http.CloseNotifier).CloseNotify()
+}
+
+// Flush implements the http.Flusher interface, for Gin. Remove with Gin.
+//
+// It panics if the underlying ResponseWriter is not a Flusher.
+func (r *statusCodeRecorder) Flush() {
+	r.ResponseWriter.(http.Flusher).Flush()
+}
+
 func (r *statusCodeRecorder) status() int {
 	return cmp.Or(r._status, 200)
 }

From 3e8b8a1933378eb62c68bc9269555efd10270e33 Mon Sep 17 00:00:00 2001
From: Michael Yang 
Date: Fri, 21 Feb 2025 11:57:08 -0800
Subject: [PATCH 15/20] ml: update Context.Forward interface

update Context.Forward to accept multiple tensors to match
Context.Compute signature

update Context.Forward to return Context such that it can be chained
with Context.Compute
---
 kvcache/causal.go       | 6 ++++--
 kvcache/causal_test.go  | 4 +---
 kvcache/encoder.go      | 6 ++++--
 ml/backend.go           | 5 ++---
 ml/backend/ggml/ggml.go | 8 ++++++--
 model/model.go          | 3 +--
 6 files changed, 18 insertions(+), 14 deletions(-)

diff --git a/kvcache/causal.go b/kvcache/causal.go
index 5d46f8d4..69068439 100644
--- a/kvcache/causal.go
+++ b/kvcache/causal.go
@@ -330,8 +330,10 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
 		c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, value.Dim(0), value.Dim(1), int(c.Capacity))
 	}
 
-	ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2))))
-	ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2))))
+	ctx.Forward(
+		key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2))),
+		value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2))),
+	)
 }
 
 func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go
index 874e4743..bbbdf836 100644
--- a/kvcache/causal_test.go
+++ b/kvcache/causal_test.go
@@ -280,9 +280,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
 
 			out, _, mask := cache.Get(context)
 
-			context.Forward(out)
-			context.Forward(mask)
-			context.Compute(out, mask)
+			context.Forward(out, mask).Compute(out, mask)
 
 			if !slices.Equal(out.Floats(), test.expected) || !slices.Equal(out.Shape(), test.expectedShape) || !slices.Equal(mask.Floats(), test.expectedMask) {
 				t.Errorf("TestCache: have %v (shape %v); want %v (shape %v); mask: have %v (shape %v) want %v", out.Floats(), out.Shape(), test.expected, test.expectedShape, mask.Floats(), mask.Shape(), test.expectedMask)
diff --git a/kvcache/encoder.go b/kvcache/encoder.go
index 8a44c194..b85b1046 100644
--- a/kvcache/encoder.go
+++ b/kvcache/encoder.go
@@ -80,8 +80,10 @@ func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
 		c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...)
 	}
 
-	ctx.Forward(key.Copy(ctx, c.keys[c.curLayer]))
-	ctx.Forward(value.Copy(ctx, c.values[c.curLayer]))
+	ctx.Forward(
+		key.Copy(ctx, c.keys[c.curLayer]),
+		value.Copy(ctx, c.values[c.curLayer]),
+	)
 }
 
 func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
diff --git a/ml/backend.go b/ml/backend.go
index a742ee5c..07bc75b6 100644
--- a/ml/backend.go
+++ b/ml/backend.go
@@ -65,7 +65,7 @@ type Context interface {
 	FromFloatSlice(s []float32, shape ...int) (Tensor, error)
 	FromIntSlice(s []int32, shape ...int) (Tensor, error)
 
-	Forward(Tensor)
+	Forward(...Tensor) Context
 	Compute(...Tensor)
 	MaxTensors() int
 	Close()
@@ -186,8 +186,7 @@ func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
 
 func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
 	if t.Bytes() == nil {
-		ctx.Forward(t)
-		ctx.Compute(t)
+		ctx.Forward(t).Compute(t)
 	}
 
 	s := make(S, mul(t.Shape()...))
diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go
index 2d7cf340..7f91990c 100644
--- a/ml/backend/ggml/ggml.go
+++ b/ml/backend/ggml/ggml.go
@@ -256,12 +256,16 @@ type Context struct {
 	nodes int
 }
 
-func (c *Context) Forward(t ml.Tensor) {
+func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
 	if c.graph == nil {
 		c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.nodes), false)
 	}
 
-	C.ggml_build_forward_expand(c.graph, t.(*Tensor).t)
+	for _, tensor := range tensors {
+		C.ggml_build_forward_expand(c.graph, tensor.(*Tensor).t)
+	}
+
+	return c
 }
 
 func (c *Context) Compute(tensors ...ml.Tensor) {
diff --git a/model/model.go b/model/model.go
index 0b5996d9..16020b35 100644
--- a/model/model.go
+++ b/model/model.go
@@ -248,8 +248,7 @@ func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	ctx.Forward(t)
-	ctx.Compute(t)
+	ctx.Forward(t).Compute(t)
 
 	return t, nil
 }

From 8b194b752042198c061f2f53797dfde5f9ac0d88 Mon Sep 17 00:00:00 2001
From: Michael Yang 
Date: Wed, 26 Feb 2025 12:16:59 -0800
Subject: [PATCH 16/20] kvcache: update tests

---
 kvcache/causal_test.go | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go
index bbbdf836..bd7d0ae8 100644
--- a/kvcache/causal_test.go
+++ b/kvcache/causal_test.go
@@ -342,7 +342,7 @@ func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
 	return out, nil
 }
 
-func (c *testContext) Forward(ml.Tensor) {}
+func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
 
 func (c *testContext) Compute(...ml.Tensor) {}
 

From c245b0406fd669bc8e3aea4e20148fa303fe2fd4 Mon Sep 17 00:00:00 2001
From: Parth Sareen 
Date: Thu, 27 Feb 2025 15:44:53 -0800
Subject: [PATCH 17/20] sample: remove transforms from greedy sampling (#9377)

---
 sample/samplers.go      | 53 ++++++++----------------
 sample/samplers_test.go | 89 ++++++++++++++++++-----------------------
 2 files changed, 55 insertions(+), 87 deletions(-)

diff --git a/sample/samplers.go b/sample/samplers.go
index 836c6e4d..1b8a5edd 100644
--- a/sample/samplers.go
+++ b/sample/samplers.go
@@ -54,53 +54,42 @@ func (s weighted) Sample(logits []float32) (int32, error) {
 	if idx, ok := w.Take(); ok {
 		return int32(indices[idx]), nil
 	}
-	return -1, errors.New("weighed sampler failed, no valid token found")
+	return -1, errors.New("weighted sampler failed, no valid token found")
 }
 
-type greedy struct {
-	transforms []Transform
-}
-
-func Greedy(transforms ...Transform) Sampler {
-	return greedy{transforms: transforms}
+type greedy struct{}
+
+func Greedy() Sampler {
+	return greedy{}
 }
 
+// Sample returns the index of the maximum value in logits.
 func (s greedy) Sample(logits []float32) (int32, error) {
-	logits64 := make([]float64, len(logits))
-	for i, v := range logits {
-		logits64[i] = float64(v)
+	if len(logits) == 0 {
+		return -1, errors.New("no logits provided for greedy sampling")
 	}
 
-	for _, t := range s.transforms {
-		logits64 = t.Apply(logits64)
-	}
-
-	var maxIdx int
-	var maxLogit float64
-	for i, logit := range logits64 {
-		if logit > maxLogit {
-			maxLogit = logit
+	maxIdx := 0
+	for i := range logits {
+		if logits[i] > logits[maxIdx] {
 			maxIdx = i
 		}
 	}
 
-	if maxLogit == math.Inf(-1) {
-		return -1, errors.New("no valid logits found for greedy sampling")
-	}
-
 	return int32(maxIdx), nil
 }
 
 // TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
 func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int) (Sampler, error) {
-	transforms := []Transform{}
+	if temperature == 0 {
+		return Greedy(), nil
+	}
+
 	if temperature < 0 || temperature > 2 {
 		return nil, errors.New("temperature must be between 0 and 2")
 	}
 
-	if temperature != 0 {
-		transforms = append(transforms, Temperature(temperature))
-	}
+	transforms := []Transform{Temperature(temperature)}
 
 	if topK != 0 {
 		if topK <= 0 {
@@ -123,15 +112,7 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
 		transforms = append(transforms, MinP(minP))
 	}
 
-	if len(transforms) == 0 {
-		return nil, errors.New("at least one transform is required")
-	}
-
-	if temperature == 0 {
-		return Greedy(transforms...), nil
-	}
-
-	if seed != 0 {
+	if seed >= 0 {
 		seed64 := uint64(seed)
 		return Weighted(&seed64, transforms...), nil
 	}
diff --git a/sample/samplers_test.go b/sample/samplers_test.go
index aaa8d99c..32364a3b 100644
--- a/sample/samplers_test.go
+++ b/sample/samplers_test.go
@@ -66,32 +66,15 @@ func TestSample(t *testing.T) {
 		callOrder: &callOrder,
 	}
 
-	got, err := Greedy(mock1, mock2, mock3).Sample(input)
+	_, err := Weighted(nil, mock1, mock2, mock3).Sample(input)
 	if err != nil {
 		t.Error(err)
 		return
 	}
-
-	want := int32(3) // Greedy sampler should pick highest logit
-	if want != got {
-		t.Errorf("index mismatch: want %d, got %d", want, got)
-	}
 	wantOrder := []int{1, 2, 3}
 	if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
 		t.Errorf("call order mismatch (-want +got):\n%s", diff)
 	}
-
-	callOrder = nil
-
-	_, err = Weighted(nil, mock1, mock2, mock3).Sample(input)
-	if err != nil {
-		t.Error(err)
-		return
-	}
-	wantOrder = []int{1, 2, 3}
-	if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
-		t.Errorf("call order mismatch (-want +got):\n%s", diff)
-	}
 }
 
 func TestNewSampler(t *testing.T) {
@@ -105,8 +88,9 @@ func TestNewSampler(t *testing.T) {
 		wantErr     bool
 	}{
 		{
-			name:    "no transforms",
-			wantErr: true,
+			name: "no transforms",
+			// temperature is 0, so greedy should be used
+			wantErr: false,
 		},
 		{
 			name:        "temperature",
@@ -124,49 +108,52 @@ func TestNewSampler(t *testing.T) {
 			wantErr:     true,
 		},
 		{
-			name:    "top k",
-			topK:    10,
-			wantErr: false,
+			name:        "top k",
+			topK:        10,
+			temperature: 0.8,
+			wantErr:     false,
 		},
 		{
-			name:    "invalid top k negative",
-			topK:    -1,
-			wantErr: true,
+			name:        "invalid top k negative",
+			topK:        -1,
+			temperature: 0.8,
+			wantErr:     true,
 		},
 		{
-			name:    "top p",
-			topP:    0.9,
-			wantErr: false,
+			name:        "top p",
+			topP:        0.9,
+			temperature: 0.8,
+			wantErr:     false,
 		},
 		{
-			name:    "invalid top p negative",
-			topP:    -0.1,
-			wantErr: true,
+			name:        "invalid top p negative",
+			topP:        -0.1,
+			temperature: 0.8,
+			wantErr:     true,
 		},
 		{
-			name:    "invalid top p one",
-			topP:    1.0,
-			wantErr: true,
+			name:        "invalid top p one",
+			topP:        1.0,
+			temperature: 0.8,
+			wantErr:     true,
 		},
 		{
-			name:    "min p",
-			minP:    0.2,
-			wantErr: false,
+			name:        "min p",
+			minP:        0.2,
+			temperature: 0.8,
+			wantErr:     false,
 		},
 		{
-			name:    "invalid min p negative",
-			minP:    -0.1,
-			wantErr: true,
+			name:        "invalid min p negative",
+			minP:        -0.1,
+			temperature: 0.8,
+			wantErr:     true,
 		},
 		{
-			name:    "invalid min p one",
-			minP:    1.0,
-			wantErr: true,
-		},
-		{
-			name:    "seed",
-			seed:    42,
-			wantErr: true, // seed alone is not valid without other transforms
+			name:        "invalid min p one",
+			minP:        1.0,
+			temperature: 0.8,
+			wantErr:     true,
 		},
 		{
 			name:        "default values",
@@ -184,7 +171,7 @@ func TestNewSampler(t *testing.T) {
 			topP:        0.0,
 			minP:        0.0,
 			seed:        0,
-			wantErr:     true, // all zeroes means no transforms
+			wantErr:     false, // all zeroes means no transforms
 		},
 		{
 			name:        "all transforms",
@@ -216,7 +203,7 @@ func BenchmarkSample(b *testing.B) {
 	}
 
 	samplers := map[string]Sampler{
-		"Greedy":   Greedy(transforms...),
+		"Greedy":   Greedy(),
 		"Weighted": Weighted(nil, transforms...),
 	}
 

From 0c1041ad851d2ce3dc4d74d5fedb1af759956688 Mon Sep 17 00:00:00 2001
From: Bruce MacDonald 
Date: Thu, 27 Feb 2025 16:41:20 -0800
Subject: [PATCH 18/20] runner: default to greedy sampler for performance
 (#9407)

As are adding support for weighted sampling we have seen some performance
regressions, bypassing the sampler logic for now and defaulting to greedy
until we can benchmark the new sampler logic.
---
 runner/ollamarunner/runner.go | 14 +-------------
 1 file changed, 1 insertion(+), 13 deletions(-)

diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index b39d747f..6b4c7be0 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -575,23 +575,11 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	sampler, err := sample.NewSampler(
-		req.Temperature,
-		req.TopK,
-		req.TopP,
-		req.MinP,
-		req.Seed,
-	)
-	if err != nil {
-		http.Error(w, fmt.Sprintf("Failed to create sampler: %v", err), http.StatusInternalServerError)
-		return
-	}
-
 	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
 		numPredict: req.NumPredict,
 		stop:       req.Stop,
 		numKeep:    int32(req.NumKeep),
-		sampler:    sampler,
+		sampler:    sample.Greedy(), // TODO: add support for different samplers when performance is optimized
 		embedding:  false,
 	})
 	if err != nil {

From 2099e2d267735042e17a78981a6992138c86572e Mon Sep 17 00:00:00 2001
From: Blake Mizerany 
Date: Thu, 27 Feb 2025 19:22:26 -0800
Subject: [PATCH 19/20] CONTRIBUTING: provide clarity on good commit messages,
 and bad (#9405)

Also, our commit messages have been getting better, but we can do
better, and be more consistent. This adds more clarity on how to write
commit messages and provides examples of good and bad messages.

Also, our contributing guide was lacking helpful guidance on how to
start change proposals. This commit adds the start of that section.

Soon, we should add a proposal template to the issue tracker with a link
back to the proposal section, which should also be expanded upon.
---
 CONTRIBUTING.md | 63 ++++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 57 insertions(+), 6 deletions(-)

diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index f003a69d..f040b9fd 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -6,8 +6,6 @@ Thank you for your interest in contributing to Ollama! Here are a few guidelines
 
 See the [development documentation](./docs/development.md) for instructions on how to build and run Ollama locally.
 
-## Pull requests
-
 ### Ideal issues
 
 * [Bugs](https://github.com/ollama/ollama/issues?q=is%3Aissue+is%3Aopen+label%3Abug): issues where Ollama stops working or where it results in an unexpected error.
@@ -26,11 +24,64 @@ See the [development documentation](./docs/development.md) for instructions on h
 * Changes that add significant friction to the user experience
 * Changes that create a large future maintenance burden for maintainers and contributors
 
-### Best practices
+## Proposing a (non-trivial) change
 
-* Commit messages: please leave both a title and a description in your commit messages. The title should be a short summary of the changes, with a leading word that explains the section of the code being changed (e.g. `api: fix parsing of prompt field`) . In the description, leave a short 2-3 sentences that explain more about the change and its impact.
-* Tests: please add test coverage to changes where possible.
-* Minimize dependencies: avoid adding new dependencies unless absolutely necessary.
+> By "non-trivial", we mean a change that is not a bug fix or small
+> documentation update. If you are unsure, please ask us on our [Discord
+> server](https://discord.gg/ollama).
+
+Before opening a non-trivial Pull Request, please open an issue to discuss the change and
+get feedback from the maintainers. This helps us understand the context of the
+change and how it fits into Ollama's roadmap and prevents us from duplicating
+work or you from spending time on a change that we may not be able to accept.
+
+Tips for proposals:
+
+* Explain the problem you are trying to solve, not what you are trying to do.
+* Explain why the change is important.
+* Explain how the change will be used.
+* Explain how the change will be tested.
+
+Additionally, for bonus points: Provide draft documentation you would expect to
+see if the change were accepted.
+
+## Pull requests
+
+**Commit messages**
+
+The title should look like:
+
+   : 
+
+The package is the most affected Go package. If the change does not affect Go
+code, then use the directory name instead. Changes to a single well-known
+file in the root directory may use the file name.
+
+The short description should start with a lowercase letter and be a
+continuation of the sentence:
+
+      "This changes Ollama to..."
+
+Examples:
+
+      llm/backend/mlx: support the llama architecture
+      CONTRIBUTING: provide clairity on good commit messages, and bad
+
+Bad Examples:
+
+      feat: add more emoji
+      fix: was not using famous web framework
+      chore: generify code
+
+**Tests**
+
+Please include tests. Strive to test behavior, not implementation.
+
+**New dependencies**
+
+Dependencies should be added sparingly. If you are adding a new dependency,
+please explain why it is necessary and what other ways you attempted that
+did not work without it.
 
 ## Need help?
 

From 98d44fa39d22c9c6f86fb964dd3bb13a38356371 Mon Sep 17 00:00:00 2001
From: Jeffrey Morgan 
Date: Thu, 27 Feb 2025 19:30:32 -0800
Subject: [PATCH 20/20] llama: add phi4 mini support (#9403)

---
 llama/llama.cpp/include/llama.h           |  1 +
 llama/llama.cpp/src/llama-model.cpp       | 10 ++-
 llama/llama.cpp/src/llama-vocab.cpp       | 11 ++++
 llama/patches/0019-add-phi4-support.patch | 80 +++++++++++++++++++++++
 4 files changed, 99 insertions(+), 3 deletions(-)
 create mode 100644 llama/patches/0019-add-phi4-support.patch

diff --git a/llama/llama.cpp/include/llama.h b/llama/llama.cpp/include/llama.h
index cc948005..16774711 100644
--- a/llama/llama.cpp/include/llama.h
+++ b/llama/llama.cpp/include/llama.h
@@ -105,6 +105,7 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_CHAMELEON      = 26,
         LLAMA_VOCAB_PRE_TYPE_MINERVA        = 27,
         LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM  = 28,
+        LLAMA_VOCAB_PRE_TYPE_GPT4O          = 29,
     };
 
     enum llama_rope_type {
diff --git a/llama/llama.cpp/src/llama-model.cpp b/llama/llama.cpp/src/llama-model.cpp
index 21819080..ab1a07d1 100644
--- a/llama/llama.cpp/src/llama-model.cpp
+++ b/llama/llama.cpp/src/llama-model.cpp
@@ -2283,7 +2283,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 
                     // output
                     output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
-                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0);
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
 
                     for (int i = 0; i < n_layer; ++i) {
                         auto & layer = layers[i];
@@ -2298,8 +2302,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
                         layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }, 0);
 
-                        layer.rope_long  = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
-                        layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+                        layer.rope_long  = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+                        layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
                     }
                 } break;
             case LLM_ARCH_PHIMOE:
diff --git a/llama/llama.cpp/src/llama-vocab.cpp b/llama/llama.cpp/src/llama-vocab.cpp
index 1ca827eb..c7ff28be 100644
--- a/llama/llama.cpp/src/llama-vocab.cpp
+++ b/llama/llama.cpp/src/llama-vocab.cpp
@@ -392,6 +392,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
                     "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
                 };
                 break;
+            case LLAMA_VOCAB_PRE_TYPE_GPT4O:
+                // original regex from tokenizer.json
+                // [^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+
+                regex_exprs = {
+                    "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
+                };
+                break;
             default:
                 // default regex for BPE tokenization pre-processing
                 regex_exprs = {
@@ -1583,6 +1590,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
             } else if (
                 tokenizer_pre == "megrez") {
                 pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2;
+            } else if (
+                tokenizer_pre == "gpt-4o") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O;
+                clean_spaces = false;
             } else {
                 LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
                 pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
diff --git a/llama/patches/0019-add-phi4-support.patch b/llama/patches/0019-add-phi4-support.patch
new file mode 100644
index 00000000..1cdc8171
--- /dev/null
+++ b/llama/patches/0019-add-phi4-support.patch
@@ -0,0 +1,80 @@
+From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
+From: jmorganca 
+Date: Thu, 27 Feb 2025 15:12:26 -0800
+Subject: [PATCH] add phi4 support
+
+---
+ include/llama.h     |  1 +
+ src/llama-model.cpp | 10 +++++++---
+ src/llama-vocab.cpp | 11 +++++++++++
+ 3 files changed, 19 insertions(+), 3 deletions(-)
+
+diff --git a/include/llama.h b/include/llama.h
+index cc948005..16774711 100644
+--- a/include/llama.h
++++ b/include/llama.h
+@@ -105,6 +105,7 @@ extern "C" {
+         LLAMA_VOCAB_PRE_TYPE_CHAMELEON      = 26,
+         LLAMA_VOCAB_PRE_TYPE_MINERVA        = 27,
+         LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM  = 28,
++        LLAMA_VOCAB_PRE_TYPE_GPT4O          = 29,
+     };
+ 
+     enum llama_rope_type {
+diff --git a/src/llama-model.cpp b/src/llama-model.cpp
+index 21819080..ab1a07d1 100644
+--- a/src/llama-model.cpp
++++ b/src/llama-model.cpp
+@@ -2283,7 +2283,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
+ 
+                     // output
+                     output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
+-                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0);
++                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
++                    // if output is NULL, init from the input tok embed
++                    if (output == NULL) {
++                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
++                    }
+ 
+                     for (int i = 0; i < n_layer; ++i) {
+                         auto & layer = layers[i];
+@@ -2298,8 +2302,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
+                         layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
+                         layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }, 0);
+ 
+-                        layer.rope_long  = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+-                        layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
++                        layer.rope_long  = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
++                        layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+                     }
+                 } break;
+             case LLM_ARCH_PHIMOE:
+diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
+index 1ca827eb..c7ff28be 100644
+--- a/src/llama-vocab.cpp
++++ b/src/llama-vocab.cpp
+@@ -392,6 +392,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
+                     "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
+                 };
+                 break;
++            case LLAMA_VOCAB_PRE_TYPE_GPT4O:
++                // original regex from tokenizer.json
++                // [^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+
++                regex_exprs = {
++                    "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
++                };
++                break;
+             default:
+                 // default regex for BPE tokenization pre-processing
+                 regex_exprs = {
+@@ -1583,6 +1590,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
+             } else if (
+                 tokenizer_pre == "megrez") {
+                 pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2;
++            } else if (
++                tokenizer_pre == "gpt-4o") {
++                pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O;
++                clean_spaces = false;
+             } else {
+                 LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
+                 pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;